# Imports

In [1]:
import json

import mlflow
from mlflow.tracking import MlflowClient
import torch
import torch.onnx

# Query registered model versions and their run metrics

In [2]:
# Name of the registered model whose versions are compared in this notebook
MODEL_NAME = "pet_classifier_model"

# MLflow client used by the following cells to query versions, runs and artifacts
client = MlflowClient()

2025/12/08 13:43:30 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/12/08 13:43:30 INFO mlflow.store.db.utils: Updating database tables
2025/12/08 13:43:30 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/08 13:43:30 INFO alembic.runtime.migration: Will assume non-transactional DDL.
2025/12/08 13:43:30 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/08 13:43:30 INFO alembic.runtime.migration: Will assume non-transactional DDL.


In [3]:
# Look up every registered version of MODEL_NAME
registry_filter = f"name='{MODEL_NAME}'"
model_versions = client.search_model_versions(registry_filter)

# Report a count line followed by one line per version
summary_lines = [f"Found {len(model_versions)} registered model versions"]
summary_lines.extend(
    f"Version {mv.version} - Run ID: {mv.run_id}" for mv in model_versions
)
print("\n".join(summary_lines))

2025/12/08 13:43:30 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/12/08 13:43:30 INFO mlflow.store.db.utils: Updating database tables
2025/12/08 13:43:30 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/08 13:43:30 INFO alembic.runtime.migration: Will assume non-transactional DDL.


Found 8 registered model versions
Version 8 - Run ID: c2e4a26506ce45f4855f68d9e5a4e05d
Version 7 - Run ID: 8e07049a246544d08378c404719db7ca
Version 6 - Run ID: f8a4428ef1484acb8addcb9c78330f49
Version 5 - Run ID: f3a460bbe2754ea0ba1d3914a591203d
Version 4 - Run ID: 318924adb4eb47bcbd78f695ae8d36d4
Version 3 - Run ID: e28f408753b946e7b2225eef7d46b50c
Version 2 - Run ID: 54384ff0c3dd407ba59985e2011252d0
Version 1 - Run ID: f33e3a17346b4ec5a416259819265da0


In [6]:
# Pick the registered version whose run logged the highest 'test_acc'.
# NOTE(review): choosing the model by its test-set score makes the reported
# test accuracy optimistic; consider selecting on 'final_val_acc' -- confirm.
if not model_versions:
    raise ValueError("No registered model versions to compare")

best_version = None
best_test_acc = -1

for version in model_versions:
    run_id = version.run_id
    run = client.get_run(run_id)
    metrics = run.data.metrics

    # A metric the run did not log is reported as -1; because of the strict
    # comparison below, such a version can never be selected.
    train_acc = metrics.get('final_train_acc', -1)
    val_acc = metrics.get('final_val_acc', -1)
    test_acc = metrics.get('test_acc', -1)

    print((
        f"Version {version.version} - "
        f"Run ID: {run_id} - "
        f"Train Acc: {train_acc:.4f} - "
        f"Val Acc: {val_acc:.4f} - "
        f"Test Acc: {test_acc:.4f}"
    ))

    # Strict '>' keeps the first version seen when several share the top score
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_version = version

# Fail with a clear message instead of an AttributeError on None below
if best_version is None:
    raise ValueError("None of the registered model versions has a 'test_acc' metric")

print(f"\nBest model: Version {best_version.version}")
print(f"Run ID: {best_version.run_id}")
print(f"Test Accuracy: {best_test_acc:.4f}")

Version 8 - Run ID: c2e4a26506ce45f4855f68d9e5a4e05d - Train Acc: 0.9655 - Val Acc: 0.7943 - Test Acc: 0.7727
Version 7 - Run ID: 8e07049a246544d08378c404719db7ca - Train Acc: 0.9736 - Val Acc: 0.9080 - Test Acc: 0.9120
Version 6 - Run ID: f8a4428ef1484acb8addcb9c78330f49 - Train Acc: 0.9406 - Val Acc: 0.9039 - Test Acc: 0.8999
Version 5 - Run ID: f3a460bbe2754ea0ba1d3914a591203d - Train Acc: 0.9396 - Val Acc: 0.8985 - Test Acc: 0.9053
Version 4 - Run ID: 318924adb4eb47bcbd78f695ae8d36d4 - Train Acc: 0.9406 - Val Acc: 0.9229 - Test Acc: 0.9066
Version 3 - Run ID: e28f408753b946e7b2225eef7d46b50c - Train Acc: 0.9320 - Val Acc: 0.9175 - Test Acc: 0.9093
Version 2 - Run ID: 54384ff0c3dd407ba59985e2011252d0 - Train Acc: 0.9406 - Val Acc: 0.9229 - Test Acc: 0.9066
Version 1 - Run ID: f33e3a17346b4ec5a416259819265da0 - Train Acc: 0.9312 - Val Acc: 0.9080 - Test Acc: 0.8931

Best model: Version 7
Run ID: 8e07049a246544d08378c404719db7ca
Test Accuracy: 0.9120


# Serialization

In [7]:
# Load the model logged under the selected run's "model" artifact path.
# map_location is forwarded by mlflow.pytorch.load_model to torch.load, so a
# checkpoint saved from a GPU can still be deserialized on a host without CUDA.
model_uri = f"runs:/{best_version.run_id}/model"
model = mlflow.pytorch.load_model(model_uri, map_location=torch.device('cpu'))

# Keep the explicit move as a safeguard, then set evaluation mode
model = model.to('cpu')
model.eval()

print("Model loaded successfully and moved to CPU")

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Model loaded successfully and moved to CPU


In [10]:
# Output file name embeds the first six characters of the selected run ID
run_tag = best_version.run_id[:6]
onnx_model_path = f"mobilenetv2_pet_classifier_run{run_tag}.onnx"

# Example input handed to the exporter: random tensor of shape (1, 3, 224, 224)
dummy_input = torch.randn(1, 3, 224, 224)

# Dimension 0 of both the input and the output is declared dynamic
batch_axes = {tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')}

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    onnx_model_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes=batch_axes,
    opset_version=18,
    export_params=True,
)

print(f"Model serialized to {onnx_model_path}")

  torch.onnx.export(


[torch.onnx] Obtain model graph for `MobileNetV2([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MobileNetV2([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 105 of general pattern rewrite rules.
Model serialized to mobilenetv2_pet_classifier_run8e0704.onnx


# Class names

In [13]:
# Fetch the class-labels file that the selected run logged as an artifact
artifact_path = "class_labels.json"
local_path = client.download_artifacts(best_version.run_id, artifact_path)

print(f"Class labels downloaded to: {local_path}")

# Load the class labels; read as UTF-8 explicitly so that non-ASCII label
# names do not depend on the platform's default encoding
with open(local_path, 'r', encoding='utf-8') as f:
    class_labels = json.load(f)

# Save to a new file for production use, named after the ONNX file's base name
output_labels_path = f"{onnx_model_path.replace('.onnx', '')}_class_labels.json"
with open(output_labels_path, 'w', encoding='utf-8') as f:
    json.dump(class_labels, f, indent=2)

print(f"Class labels saved to {output_labels_path}")

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Class labels downloaded to: /tmp/tmpgfm0zbf5/class_labels.json
Class labels saved to mobilenetv2_pet_classifier_run8e0704_class_labels.json
