In [1]:
import torch
import torch.nn as nn
import torchvision
import PIL
import io
import json
import os
import numpy as np
import logging

In [2]:
class SegmentationModel(nn.Module):
    def __init__(self, seg_path='FoliageFixerModel/models/mobilenetv2.3'):
        super().__init__()
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            print(torch.cuda.get_device_name(self.device))
        else:
            self.device = None
            print('GPU is not available')
        # Define your segmentation model here
        # self.segmentation = models.segmentation.__dict__["fcn_resnet50"](pretrained=True)
        self.segmentation = torch.load(seg_path, map_location=torch.device('cpu'))
        # Freeze the segmentation layers
        for param in self.segmentation.parameters():
            param.requires_grad = False

    def forward(self, x):
        device = self.device
        # Forward pass through the segmentation model
        # x = x/255.0
        # resize
        x = torchvision.transforms.Resize(size=(512,512))(x)
        if x.shape[1] == 4:
          # if batch_size is 1
          if x.shape[0] == 1:
            img = x.squeeze(0)
            # transpose to shape: 512, 512, 4
            img = np.transpose(img, (1,2,0))
            pil_image = PIL.Image.fromarray(img.numpy(), 'RGBA')
            rgb_image = pil_image.convert('RGB')
            rgb_array = np.asarray(rgb_image, dtype=np.float32)
            x = torch.from_numpy(np.transpose(rgb_array, (2,1,0)))
            x = x.unsqueeze(0)
        x = x.to(device=device)
        # segment image first
        outputs = self.segmentation(x)
        # Apply softmax activation function to the output
        probs = torch.softmax(outputs, dim=1)
        # Get the predicted labels
        _, labels = torch.max(probs, dim=1)

        disease_mask = (labels == 2).float()
        disease_mask = torch.unsqueeze(disease_mask, 1)
        healthy_mask = (labels == 1).float()
        healthy_mask = torch.unsqueeze(healthy_mask, 1)
        leaf_mask = disease_mask + healthy_mask

        disease = x * disease_mask
        leaf = x * leaf_mask
        return (leaf, disease)

class ClassificationModel(nn.Module):
    def __init__(self, num_classes=8):
        super().__init__()
        # Define your classification model here
        self.classification = torchvision.models.resnet18(pretrained=True)
        # Replace the last layer with a new layer that has num_classes outputs
        num_features = self.classification.fc.in_features
        self.classification.fc = nn.Linear(num_features, num_classes)
        # self.label_dict = train_set.class_to_idx
    
    def forward(self, x):
        # Forward pass through the classification model
        x = self.classification(x)
        return x

### Convert to ONNX

In [3]:
'''
Load segmentation model
'''
segmentation_model = SegmentationModel(seg_path='saved_seg_models/mobilenetv3')

'''
Load classification model
'''
classification_model = ClassificationModel()
weights_path = 'saved_seg_models/classification-v5.1_stateDict'
classification_model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
classification_model.train(mode=False)

GPU is not available


  from .autonotebook import tqdm as notebook_tqdm


ClassificationModel(
  (classification): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, aff

In [4]:
batch_size = 1
channels = 3
height = 512 #model resizes anyway
width = 512 #model resizes anyway
sample_input_seg = torch.rand((batch_size, channels, height, width))

In [5]:
onnx_model_path = 'onnx_model/segmodelv3.onnx'

torch.onnx.export(
    segmentation_model,                  # PyTorch Model
    sample_input_seg,                    # Input tensor
    onnx_model_path,        # Output file (eg. 'output_model.onnx')
    opset_version=12,       # Operator support version
    input_names=['input'],   # Input tensor name (arbitary)
    output_names=['output_leaf', 'output_disease'], # Output tensor name (arbitary)
    # dynamic_axes={"input": {2: "width", 3: "height"}}
)

  if x.shape[1] == 4:
  if h % output_stride != 0 or w % output_stride != 0:


verbose: False, log level: Level.ERROR



In [6]:
batch_size = 1
channels = 3
height = 512 #model resizes anyway
width = 512 #model resizes anyway
sample_input_class = torch.rand((batch_size, channels, height, width))

In [7]:
onnx_model_path = 'onnx_model/classmodelv5.1.onnx'

torch.onnx.export(
    classification_model,                  # PyTorch Model
    sample_input_class,                    # Input tensor
    onnx_model_path,        # Output file (eg. 'output_model.onnx')
    opset_version=12,       # Operator support version
    input_names=['input'],   # Input tensor name (arbitary)
    output_names=['output'], # Output tensor name (arbitary)
    # dynamic_axes={"input": {2: "width", 3: "height"}}
)

verbose: False, log level: Level.ERROR



#Verification

In [8]:
import onnx

In [9]:
# Load the ONNX model
model = onnx.load("onnx_model/classmodelv5.1.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a Human readable representation of the graph
onnx.helper.printable_graph(model.graph)

'graph torch_jit (\n  %input[FLOAT, 1x3x512x512]\n) initializers (\n  %classification.fc.weight[FLOAT, 8x512]\n  %classification.fc.bias[FLOAT, 8]\n  %onnx::Conv_193[FLOAT, 64x3x7x7]\n  %onnx::Conv_194[FLOAT, 64]\n  %onnx::Conv_196[FLOAT, 64x64x3x3]\n  %onnx::Conv_197[FLOAT, 64]\n  %onnx::Conv_199[FLOAT, 64x64x3x3]\n  %onnx::Conv_200[FLOAT, 64]\n  %onnx::Conv_202[FLOAT, 64x64x3x3]\n  %onnx::Conv_203[FLOAT, 64]\n  %onnx::Conv_205[FLOAT, 64x64x3x3]\n  %onnx::Conv_206[FLOAT, 64]\n  %onnx::Conv_208[FLOAT, 128x64x3x3]\n  %onnx::Conv_209[FLOAT, 128]\n  %onnx::Conv_211[FLOAT, 128x128x3x3]\n  %onnx::Conv_212[FLOAT, 128]\n  %onnx::Conv_214[FLOAT, 128x64x1x1]\n  %onnx::Conv_215[FLOAT, 128]\n  %onnx::Conv_217[FLOAT, 128x128x3x3]\n  %onnx::Conv_218[FLOAT, 128]\n  %onnx::Conv_220[FLOAT, 128x128x3x3]\n  %onnx::Conv_221[FLOAT, 128]\n  %onnx::Conv_223[FLOAT, 256x128x3x3]\n  %onnx::Conv_224[FLOAT, 256]\n  %onnx::Conv_226[FLOAT, 256x256x3x3]\n  %onnx::Conv_227[FLOAT, 256]\n  %onnx::Conv_229[FLOAT, 256x1

In [10]:
# Load the ONNX model
model = onnx.load("onnx_model/segmodelv3.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a Human readable representation of the graph
onnx.helper.printable_graph(model.graph)

"graph torch_jit (\n  %input[FLOAT, 1x3x512x512]\n) initializers (\n  %segmentation.segmentation_head.0.weight[FLOAT, 3x16x3x3]\n  %segmentation.segmentation_head.0.bias[FLOAT, 3]\n  %onnx::Conv_825[FLOAT, 32x3x3x3]\n  %onnx::Conv_826[FLOAT, 32]\n  %onnx::Conv_828[FLOAT, 32x1x3x3]\n  %onnx::Conv_829[FLOAT, 32]\n  %onnx::Conv_831[FLOAT, 16x32x1x1]\n  %onnx::Conv_832[FLOAT, 16]\n  %onnx::Conv_834[FLOAT, 96x16x1x1]\n  %onnx::Conv_835[FLOAT, 96]\n  %onnx::Conv_837[FLOAT, 96x1x3x3]\n  %onnx::Conv_838[FLOAT, 96]\n  %onnx::Conv_840[FLOAT, 24x96x1x1]\n  %onnx::Conv_841[FLOAT, 24]\n  %onnx::Conv_843[FLOAT, 144x24x1x1]\n  %onnx::Conv_844[FLOAT, 144]\n  %onnx::Conv_846[FLOAT, 144x1x3x3]\n  %onnx::Conv_847[FLOAT, 144]\n  %onnx::Conv_849[FLOAT, 24x144x1x1]\n  %onnx::Conv_850[FLOAT, 24]\n  %onnx::Conv_852[FLOAT, 144x24x1x1]\n  %onnx::Conv_853[FLOAT, 144]\n  %onnx::Conv_855[FLOAT, 144x1x3x3]\n  %onnx::Conv_856[FLOAT, 144]\n  %onnx::Conv_858[FLOAT, 32x144x1x1]\n  %onnx::Conv_859[FLOAT, 32]\n  %onnx::C

### Convert to Tensorflow

In [11]:
import onnx

seg_onnx_model = onnx.load('onnx_model/segmodelv3.onnx')
class_onnx_model = onnx.load('onnx_model/classmodelv5.1.onnx')

In [12]:
from onnx_tf.backend import prepare

seg_tf_rep = prepare(seg_onnx_model)
class_tf_rep = prepare(class_onnx_model)

2023-09-08 15:51:18.623182: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [13]:
seg_tf_rep.export_graph('tf_model/segmodelv3')
class_tf_rep.export_graph('tf_model/classmodelv5.1')

INFO:tensorflow:Assets written to: tf_model/segmodelv3/assets


INFO:tensorflow:Assets written to: tf_model/segmodelv3/assets
INFO:absl:Function `__call__` contains input name(s) x, y with unsupported characters which will be renamed to transpose_62_x, add_19_y in the SavedModel.
INFO:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: tf_model/classmodelv5.1/assets


INFO:tensorflow:Assets written to: tf_model/classmodelv5.1/assets
INFO:absl:Writing fingerprint to tf_model/classmodelv5.1/fingerprint.pb


In [14]:
import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model('tf_model/segmodelv3')
tflite_model = converter.convert()

# Save the model
with open('tf_lite/segmodelv3.tflite', 'wb') as f:
    f.write(tflite_model)

    # Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model('tf_model/classmodelv5.1')
tflite_model = converter.convert()

# Save the model
with open('tf_lite/classmodelv5.1.tflite', 'wb') as f:
    f.write(tflite_model)

2023-09-08 15:51:54.005038: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2023-09-08 15:51:54.005104: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2023-09-08 15:51:54.005974: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: tf_model/segmodelv3
2023-09-08 15:51:54.018332: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2023-09-08 15:51:54.018390: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: tf_model/segmodelv3
2023-09-08 15:51:54.035874: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:375] MLIR V1 optimization pass is not enabled
2023-09-08 15:51:54.040191: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2023-09-08 15:51:54.151737: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: tf_model/segmodelv

### Classification Inference

In [26]:
import numpy as np
import tensorflow as tf
import torch

label_dict = {
    'Bacterial Spot': 0, 
    'Early Blight': 1, 
    'Healthy': 2, 
    'Late Blight': 3,
    'Leaf Mold': 4, 
    'Septoria Leaf Spot': 5, 
    'Tomato Mosaic Virus': 6, 
    'Yellow Leaf Curl Virus': 7
    }

def get_classification(outputs):
    outputs = torch.from_numpy(outputs)
    probabilities = torch.softmax(outputs, dim=1)
    # this id may not be related to database ids
    predicted_class_id = torch.argmax(probabilities, dim=1)
    predicted_class = get_class_from_id(predicted_class_id.cpu().numpy()[0])
    return predicted_class

def get_class_from_id(id):
  label = list(label_dict.keys())[list(label_dict.values()).index(id)]
  return label

# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path="tf_lite/classmodelv5.1.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
import torchvision
input_data = torchvision.io.read_image('dataset-reduced/train/Healthy/0a0d6a11-ddd6-4dac-8469-d5f65af5afca___RS_HL 0555_flipTB.JPG')
input_data = torchvision.transforms.functional.convert_image_dtype(input_data, dtype=torch.float32)
input_data = torchvision.transforms.Resize((512,512))(input_data)
input_data = input_data.unsqueeze(0)
print(input_shape)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# get_tensor() returns a copy of the tensor data
# use tensor() in order to get a pointer to the tensor
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

pred = get_classification(output_data)
print(pred)



[  1   3 512 512]
[[ -7.0611405   1.625801    3.374012    3.253458  -20.40204    -3.9611185
   -6.650133  -23.108    ]]
Healthy


### Segmentation Inference

In [28]:
import numpy as np
import tensorflow as tf
import torch

# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path="tf_lite/segmodelv3.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# get_tensor() returns a copy of the tensor data
# use tensor() in order to get a pointer to the tensor
leaf_data = interpreter.get_tensor(output_details[0]['index'])
disease_data = interpreter.get_tensor(output_details[1]['index'])

print(leaf_data.shape)

(1, 3, 512, 512)
