In [7]:
import os
import subprocess
import sys

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
# Create the output folders for both export formats; exist_ok=True keeps this cell re-runnable.
for _converted_dir in ("../converted_models/ONNX", "../converted_models/TFLITE"):
    os.makedirs(_converted_dir, exist_ok=True)

In [3]:


# # Convert the fp16-quantized TFLite model to ONNX
# saved_model_dir = "../best_model/model1/best_f1score_fold"
# tflife_model_dir = "../converted_models/TFLITE/fp16_quantized_model.tflite"
# onnx_model_path = "../converted_models/ONNX/fp16_converted_model.onnx"

# # Convert the model
# # !python -m tflite2onnx.convert --saved-model {tflife_model_dir} --output {onnx_model_path} --opset 13
# !python -m tf2onnx.convert --tflite {tflife_model_dir} --output {onnx_model_path} --opset 13

# print(f"ONNX model saved at: {onnx_model_path}")


In [None]:
def convert_savedmodel_to_onnx(
    saved_model_dir: str,
    onnx_model_path: str = "../converted_models/converted_model.onnx",
    opset: str = "13",
):
    """
    Convert a TensorFlow SavedModel to ONNX with the ``tf2onnx`` command-line tool.

    Parameters:
        saved_model_dir (str): Path to the TensorFlow SavedModel directory.
        onnx_model_path (str): Output path of the ONNX file. Its parent directory is
            created if it does not exist.
        opset (str | int): ONNX opset version passed to tf2onnx's ``--opset`` flag.

    Returns:
        str: Path of the written ONNX model.

    Raises:
        subprocess.CalledProcessError: If the tf2onnx process exits with a non-zero status.
    """
    output_dir = os.path.dirname(onnx_model_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # sys.executable runs tf2onnx from the same environment as this kernel, instead of
    # whichever "python" happens to be first on PATH.
    savedmodel2onnx_command = [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", saved_model_dir,
        "--output", onnx_model_path,
        "--opset", str(opset),
    ]

    print(savedmodel2onnx_command)

    # Run the command; tf2onnx output streams straight to the notebook.
    try:
        subprocess.run(savedmodel2onnx_command, check=True)
        print(f"ONNX model successfully saved at: {onnx_model_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error occurred during ONNX conversion: {e}")
        raise
    return onnx_model_path

def convert_savedmodel_to_tflite(
        saved_model_dir: str,
        representative_dataset_gen=None,
        quantization_type: str = "dynamic",
        tflite_model_path: "str | None" = None,
    ):
    """
    Convert a TensorFlow SavedModel to a post-training-quantized TensorFlow Lite model.

    Parameters:
        saved_model_dir (str): Path to the TensorFlow SavedModel directory.
        representative_dataset_gen (callable): Zero-argument generator function yielding
            calibration inputs. Required for "integer" quantization, unused otherwise.
        quantization_type (str): Type of post-training quantization. Options: "dynamic", "integer", "fp16".
        tflite_model_path (str | None): Where to write the .tflite file. When None (the default),
            the model is written to "../converted_models/TFLITE/{quantization_type}_quantized_model.tflite".
            The parent directory is created if missing.

    Returns:
        str: Path of the written TFLite model.

    Raises:
        ValueError: If quantization_type is unsupported, or "integer" is requested without
            a representative dataset generator.
    """
    # Validate arguments before paying for the SavedModel load.
    if quantization_type not in ("dynamic", "fp16", "integer"):
        raise ValueError(f"Unsupported quantization type: {quantization_type}")
    if quantization_type == "integer" and representative_dataset_gen is None:
        raise ValueError("A representative dataset generator must be provided for integer quantization.")

    # Load the TensorFlow model
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

    # All three schemes start from the default optimization set; on its own this is
    # dynamic-range quantization.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if quantization_type == "fp16":
        # Quantize weights to float16.
        converter.target_spec.supported_types = [tf.float16]
    elif quantization_type == "integer":
        # Calibrate activation ranges on sample data and allow only int8 builtin ops
        # (conversion fails if an op has no int8 implementation).
        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

    # Convert the model to TensorFlow Lite format
    tflite_model = converter.convert()
    print(f"{quantization_type} quantization complete!")

    # Honour a caller-supplied path; otherwise use the notebook's standard location.
    if tflite_model_path is None:
        tflite_model_path = f"../converted_models/TFLITE/{quantization_type}_quantized_model.tflite"
    output_dir = os.path.dirname(tflite_model_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)

    print(f"Quantized TFLite model saved at: {tflite_model_path}")
    return tflite_model_path

def convert_tflite_to_onnx(
    tflite_model_path: str,
    quantization_type: str = "dynamic",
    onnx_model_path: "str | None" = None,
    opset: str = "13"
    ):
    """
    Convert a TensorFlow Lite model to ONNX format with the ``tf2onnx`` command-line tool.

    Parameters:
        tflite_model_path (str): Path of the input .tflite model to convert.
        quantization_type (str): Quantization label ("dynamic", "integer", "fp16"), used only
            to name the default output file.
        onnx_model_path (str | None): Where to write the ONNX model. When None (the default),
            the model is written to "../converted_models/ONNX/{quantization_type}_quantized_model.onnx".
        opset (str | int): ONNX opset version passed to tf2onnx's ``--opset`` flag.

    Returns:
        str: Path of the written ONNX model.

    Raises:
        subprocess.CalledProcessError: If the tf2onnx process exits with a non-zero status.
    """
    # Honour a caller-supplied path; otherwise use the notebook's standard location.
    if onnx_model_path is None:
        onnx_model_path = f"../converted_models/ONNX/{quantization_type}_quantized_model.onnx"

    # sys.executable runs tf2onnx from the same environment as this kernel, instead of
    # whichever "python" happens to be first on PATH.
    tflite2onnx_command = [
        sys.executable, "-m", "tf2onnx.convert",
        "--tflite", tflite_model_path,
        "--output", onnx_model_path,
        "--opset", str(opset),
    ]

    print(tflite2onnx_command)

    # Run the command, capturing output so stderr can be shown if the conversion fails.
    try:
        result = subprocess.run(tflite2onnx_command, check=True, capture_output=True, text=True)
        print(f"ONNX model successfully saved at: {onnx_model_path}")
        print("Output:", result.stdout)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred during ONNX conversion: {e}")
        print("stderr:", e.stderr)
        raise
    return onnx_model_path

# representative dataset generator (for integer quantization)
def create_data_generator(data_dir, batch_size=32, img_size=(224, 224), shuffle=True, seed=42):
    """
    Create a Keras directory iterator over image batches for quantization calibration.

    Images are read from the class subdirectories of ``data_dir``, resized to ``img_size``
    and multiplied by 1/255 (i.e. scaled to [0, 1], assuming 8-bit pixel values). Labels
    are not produced (``class_mode=None``), so each batch is just an image array.

    Note: the returned iterator cycles over the dataset indefinitely. Consumers must stop
    after a bounded number of batches; ``len(iterator)`` batches is one full pass.

    Args:
        data_dir (str): Path to the data directory containing 'MonkeyPox' and 'Others' subdirectories.
        batch_size (int): The size of the batches to return. Default is 32.
        img_size (tuple): Target image size (height, width); should match the model input. Default is (224, 224).
        shuffle (bool): Whether to shuffle image order. Default is True (the Keras default).
        seed (int): Random seed for shuffling, for reproducibility. Default is 42.

    Returns:
        DirectoryIterator: A Keras iterator yielding batches of image data.
    """
    # Rescale pixel values; no augmentation is applied to calibration data.
    datagen = ImageDataGenerator(rescale=1./255)

    return datagen.flow_from_directory(
        data_dir,
        target_size=img_size,
        batch_size=batch_size,
        class_mode=None,  # labels are not needed for quantization
        shuffle=shuffle,
        seed=seed,
    )

# Build the calibration image iterator once; representative_dataset_gen draws batches from it.
generator = create_data_generator(data_dir="../data/Augmented_Images")
# Define a representative dataset function for quantization
def representative_dataset_gen(data_iter=None, num_batches=None):
    """
    Yield a finite number of image batches for INT8 calibration.

    The TFLite converter consumes this generator during calibration. The Keras directory
    iterator loops over its data indefinitely, so iterating it with a plain ``for`` loop
    never ends; this function draws a bounded number of batches instead.

    Args:
        data_iter: Iterator of image batches supporting ``next()`` and ``len()``. Defaults
            to the module-level ``generator``.
        num_batches (int | None): Number of batches to yield. Defaults to ``len(data_iter)``,
            i.e. one pass over the dataset.

    Yields:
        list: A one-element list holding a float32 batch of images, the per-sample input
        format used for ``TFLiteConverter.representative_dataset``.
    """
    if data_iter is None:
        data_iter = generator
    if num_batches is None:
        num_batches = len(data_iter)
    for _ in range(num_batches):
        batch = next(data_iter)
        yield [np.asarray(batch, dtype=np.float32)]


Found 3192 images belonging to 2 classes.


In [None]:
# convert_savedmodel_to_onnx("../best_model/model1/best_f1score_fold")
# Quantization schemes to run; any subset of ["dynamic", "fp16", "integer"].
quantization_types = ["integer"]
saved_model_dir = "../best_model/model1/best_f1score_fold"

for quant_type in quantization_types:
    convert_savedmodel_to_tflite(
        saved_model_dir,
        quantization_type=quant_type,
        representative_dataset_gen=representative_dataset_gen,
    )
    # File written by convert_savedmodel_to_tflite for this quantization type.
    tflite_model_path = f"../converted_models/TFLITE/{quant_type}_quantized_model.tflite"
    # convert_tflite_to_onnx(tflite_model_path, quantization_type=quant_type)



integer quantization complete!
