# Imports

In [1]:
import os
import json

import mlflow
from mlflow.tracking import MlflowClient
import torch
import torch.onnx
import onnx

from config import Config

# Query of experiments

In [2]:
# Initialize MLflow client
client = MlflowClient()
MODEL_NAME = Config.MODEL_NAME

2025/12/14 20:58:17 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/12/14 20:58:17 INFO mlflow.store.db.utils: Updating database tables
2025/12/14 20:58:17 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/14 20:58:17 INFO alembic.runtime.migration: Will assume non-transactional DDL.
2025/12/14 20:58:17 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/14 20:58:17 INFO alembic.runtime.migration: Will assume non-transactional DDL.


In [3]:
# Get all versions of the registered model
model_versions = client.search_model_versions(f"name='{MODEL_NAME}'")

print(f"Found {len(model_versions)} registered model versions")
for version in model_versions:
    print(f"Version {version.version} - Run ID: {version.run_id}")

2025/12/14 20:58:17 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/12/14 20:58:17 INFO mlflow.store.db.utils: Updating database tables
2025/12/14 20:58:17 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/14 20:58:17 INFO alembic.runtime.migration: Will assume non-transactional DDL.


Found 6 registered model versions
Version 6 - Run ID: 0bf33d7bf6dd49f28bcf7ed6a303bfed
Version 5 - Run ID: d03735a2502542bcadfd5a761b64e0a9
Version 4 - Run ID: 4b5aa21e4b584a1da5edac84e4d80710
Version 3 - Run ID: 481696f1fed245b8950fff507181bbf8
Version 2 - Run ID: 78fc84fa9f5447dc8b5b37d47474a3e7
Version 1 - Run ID: c84f294d9f1d47a292d25912c4c1c030


In [4]:
best_version = None
best_f1_score = -1

for version in model_versions:
    run_id = version.run_id
    run = client.get_run(run_id)
    metrics = run.data.metrics
    
    # Get metrics
    test_f1_score = metrics.get('test_f1_score', -1)
    
    print((
        f"Version {version.version} - "
        f"Run ID: {run_id} - "
        f"Test Acc: {test_f1_score:.4f}"
    ))
    
    if test_f1_score > best_f1_score:
        best_f1_score = test_f1_score
        best_version = version

print(f"\nBest model: Version {best_version.version}")
print(f"Run ID: {best_version.run_id}")
print(f"Test Accuracy: {best_f1_score:.4f}")

Version 6 - Run ID: 0bf33d7bf6dd49f28bcf7ed6a303bfed - Test Acc: 0.9875
Version 5 - Run ID: d03735a2502542bcadfd5a761b64e0a9 - Test Acc: 0.9891
Version 4 - Run ID: 4b5aa21e4b584a1da5edac84e4d80710 - Test Acc: 0.9891
Version 3 - Run ID: 481696f1fed245b8950fff507181bbf8 - Test Acc: 0.9907
Version 2 - Run ID: 78fc84fa9f5447dc8b5b37d47474a3e7 - Test Acc: 0.9891
Version 1 - Run ID: c84f294d9f1d47a292d25912c4c1c030 - Test Acc: 0.9876

Best model: Version 3
Run ID: 481696f1fed245b8950fff507181bbf8
Test Accuracy: 0.9907


# Serialization

In [5]:
# Load the model
model_uri = f"runs:/{best_version.run_id}/model"
model = mlflow.pytorch.load_model(model_uri)

# Move to CPU and set to evaluation mode
model = model.to('cpu')
model.eval()

print("Model loaded successfully and moved to CPU")

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/6 [00:00<?, ?it/s]

Model loaded successfully and moved to CPU


In [6]:
# dummy input for ONNX export
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
onnx_model_path = f"{Config.MODEL_NAME}_{best_version.run_id[0:6]}.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_model_path,
    export_params=True,
    opset_version=18,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

# Convert external data to embedded format
print("Converting to embedded format...")
onnx_model = onnx.load(onnx_model_path, load_external_data=True)
onnx.save(onnx_model, onnx_model_path)

# Remove the external data file if it exists
external_data_path = onnx_model_path + ".data"
if os.path.exists(external_data_path):
    os.remove(external_data_path)
    print(f"Removed {external_data_path}")

print(f"Model serialized to {onnx_model_path} (embedded weights)")

  torch.onnx.export(


Converting to embedded format...
Model serialized to cats_and_dogs_mobilenet_v2_481696.onnx (embedded weights)


# Class names

In [11]:
artifact_path = "labels.json"
local_path = client.download_artifacts(best_version.run_id, artifact_path)

print(f"Class labels downloaded to: {local_path}")

# Load the class labels
with open(local_path, 'r') as f:
    class_labels = json.load(f)

# Save to a new file for production use
output_labels_path = f"{onnx_model_path.replace('.onnx', '')}_labels.json"
with open(output_labels_path, 'w') as f:
    json.dump(class_labels, f, indent=2)

print(f"Class labels saved to {output_labels_path}")

best_params = client.get_run(best_version.run_id).data.params
best_params_path = f"{onnx_model_path.replace('.onnx', '')}_params.json"
with open(best_params_path, 'w') as f:
    json.dump(best_params, f, indent=2)

print(f"Params saved to {output_labels_path}")

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Class labels downloaded to: /tmp/tmphqrzfqe3/labels.json
Class labels saved to cats_and_dogs_mobilenet_v2_481696_labels.json
Params saved to cats_and_dogs_mobilenet_v2_481696_labels.json
