In [1]:
import mlflow
import mlflow.pytorch
import mlflow.onnx
import onnx
import onnxruntime
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
import shutil
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define the neural network model
class Net(nn.Module):
    """Small two-convolution CNN for single-channel 28x28 images.

    Shape flow for an input of shape (N, 1, 28, 28):
    conv1 -> (N, 32, 26, 26) -> ReLU
    conv2 -> (N, 64, 24, 24) -> 2x2 max-pool -> (N, 64, 12, 12) -> ReLU
    flatten -> (N, 9216) -> fc1 -> (N, 128) -> ReLU -> fc2 -> (N, 10)

    The fc1 width (9216 = 64 * 12 * 12) is tied to the 28x28 input size.
    ``forward`` returns per-class log-probabilities (log_softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1, conv2, fc1, fc2) fixes state_dict keys
        # and the order in which parameters consume the RNG at init.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = torch.relu(self.conv1(x))
        # Pool before the second ReLU (ReLU and max-pool commute).
        features = torch.relu(torch.max_pool2d(self.conv2(features), 2))
        hidden = torch.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(hidden)
        return torch.log_softmax(logits, dim=1)

In [3]:
# Training settings
batch_size = 64      # mini-batch size for both the train and test loaders
epochs = 5           # full passes over the training set
lr = 0.01            # SGD learning rate
momentum = 0.5       # SGD momentum
log_interval = 10    # log the training loss every N batches
seed = 1             # RNG seed for reproducible weight init and shuffling

# Seed torch's default generator here, before the data loaders and the model
# are created, so Restart & Run All reproduces the same initial weights and
# the same shuffled batch order.
torch.manual_seed(seed)

In [4]:
# Initialize data loaders. Both splits share one preprocessing pipeline:
# image -> float tensor, then normalize with the commonly used MNIST
# mean/std constants (0.1307, 0.3081). Only the training split is shuffled.
# The test split passes no download flag and relies on the train-split call
# having fetched the dataset files into '../data'.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    datasets.MNIST('../data', train=False, transform=transform),
    batch_size=batch_size,
    shuffle=False,
)


In [5]:
# Initialize model, loss function, and optimizer.
model = Net()
# Net.forward already ends in log_softmax, so the matching loss is NLLLoss.
# CrossEntropyLoss would apply log_softmax a second time: mathematically a
# no-op (log_softmax is idempotent) but redundant work and misleading about
# what the model outputs. Default reduction is 'mean', as before.
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [6]:
def _train_one_epoch(epoch):
    """Run one SGD pass over ``train_loader`` and log the loss to MLflow.

    The loss is logged every ``log_interval`` batches under the global step
    ``epoch * batches_per_epoch + batch_idx``, so curves from successive
    epochs line up on a single x-axis.

    Args:
        epoch: Zero-based epoch index, used only to compute the logging step.
    """
    model.train()
    batches_per_epoch = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            mlflow.log_metric('loss', loss.item(), step=epoch * batches_per_epoch + batch_idx)


def _evaluate():
    """Evaluate ``model`` on ``test_loader``.

    Returns:
        ``(mean_loss_per_sample, accuracy_percent)`` over the whole test set.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # criterion is built with the default 'mean' reduction, so weight
            # each batch mean by the batch size to get a per-sample total.
            # Summing batch means and dividing by the dataset size would
            # understate the loss by roughly a factor of batch_size.
            total_loss += criterion(output, target).item() * target.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    n_samples = len(test_loader.dataset)
    return total_loss / n_samples, 100. * correct / n_samples


def train_model():
    """Train the global ``model`` on MNIST and track the run in MLflow.

    Under the ``MNIST_ONNX_Experiment`` experiment, this logs:
    - the hyperparameters;
    - the training loss every ``log_interval`` batches;
    - the test loss and accuracy after every epoch.

    It then exports the network to ``mnist_model.onnx`` (current directory)
    and logs it as an MLflow ONNX model under the ``mnist_model`` artifact path.

    Returns:
        The trained global ``model`` object.
    """
    mlflow.set_experiment("MNIST_ONNX_Experiment")
    with mlflow.start_run():
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("learning_rate", lr)
        mlflow.log_param("momentum", momentum)

        for epoch in range(epochs):
            _train_one_epoch(epoch)
            test_loss, accuracy = _evaluate()
            mlflow.log_metric('test_loss', test_loss, step=epoch)
            mlflow.log_metric('accuracy', accuracy, step=epoch)

        # Convert and log the model in ONNX format (batch dimension fixed at
        # 1 by the example input; no dynamic axes are requested here).
        dummy_input = torch.randn(1, 1, 28, 28)
        torch.onnx.export(model, dummy_input, "mnist_model.onnx")
        mlflow.onnx.log_model(onnx_model=onnx.load("mnist_model.onnx"), artifact_path="mnist_model")

    print('Training complete.')
    return model

In [7]:
def convert_to_onnx(model):
    """Export ``model`` to ``model.onnx`` in the current working directory.

    The graph is traced with a single random (1, 1, 28, 28) example input and
    no dynamic axes, so the exported batch dimension is fixed at 1.

    Args:
        model: The ``torch.nn.Module`` to export.
    """
    sample_batch = torch.randn(1, 1, 28, 28)
    torch.onnx.export(model, args=(sample_batch,), f="model.onnx")

In [8]:
def enable_dynamic_batching(src_path="model.onnx", dst_path="model_dynamic.onnx", batch_dim_name="batch_size"):
    """Make the leading (batch) dimension of an exported ONNX model symbolic.

    Loads ``src_path`` and sets dimension 0 of every graph input and every
    graph output to the symbolic name ``batch_dim_name``. The result is saved
    to ``dst_path``.

    Args:
        src_path: ONNX file to read. Defaults to ``"model.onnx"``.
        dst_path: Where to write the patched model. Defaults to
            ``"model_dynamic.onnx"``.
        batch_dim_name: Symbolic name given to the batch dimension.
    """
    onnx_model = onnx.load(src_path)
    graph = onnx_model.graph
    # Some exports list weights in graph.input as well (initializers kept as
    # inputs). Those are parameters, not batch-carrying tensors, so their
    # shapes must not be touched.
    initializer_names = {init.name for init in graph.initializer}
    for value_info in list(graph.input) + list(graph.output):
        if value_info.name in initializer_names:
            continue
        dims = value_info.type.tensor_type.shape.dim
        if not dims:
            # Rank-0 (or unknown-rank) tensor: there is no batch dim to relax.
            continue
        # dim_param and dim_value share a oneof, so this replaces the fixed size.
        dims[0].dim_param = batch_dim_name
    onnx.save(onnx_model, dst_path)

In [9]:
def prepare_triton_repository(max_batch_size=64):
    """Lay out a Triton model repository for the dynamic-batch ONNX model.

    Moves ``model_dynamic.onnx`` into ``model_repository/mnist_model/1/model.onnx``
    and writes ``model_repository/mnist_model/config.pbtxt`` for the
    ``onnxruntime_onnx`` platform. The tensor names in the config are read
    from the deployed file.

    Args:
        max_batch_size: Largest batch Triton may assemble.
            - If > 0 (default 64, matching the training batch size), Triton's
              dynamic batcher is enabled. Tensor ``dims`` then omit the batch
              dimension, because Triton prepends it itself.
            - If 0, Triton-side batching is disabled. The batch dimension is
              then declared explicitly as -1 in ``dims``.
    """
    repo_dir = os.path.join("model_repository", "mnist_model")
    version_dir = os.path.join(repo_dir, "1")
    deployed_model_path = os.path.join(version_dir, "model.onnx")

    os.makedirs(version_dir, exist_ok=True)
    shutil.move("model_dynamic.onnx", deployed_model_path)

    # Take tensor names from the file Triton will actually serve.
    input_name, output_name = get_model_io_names(deployed_model_path)

    if max_batch_size > 0:
        # Batching enabled: dims describe a single sample.
        input_dims, output_dims = "[ 1, 28, 28 ]", "[ 10 ]"
        batching_section = "dynamic_batching { }\n"
    else:
        # max_batch_size 0 means "no Triton batching"; the model's own
        # variable batch dimension must then be spelled out as -1.
        input_dims, output_dims = "[ -1, 1, 28, 28 ]", "[ -1, 10 ]"
        batching_section = ""

    config = f"""name: "mnist_model"
platform: "onnxruntime_onnx"
max_batch_size: {max_batch_size}
input [
  {{
    name: "{input_name}"
    data_type: TYPE_FP32
    dims: {input_dims}
  }}
]
output [
  {{
    name: "{output_name}"
    data_type: TYPE_FP32
    dims: {output_dims}
  }}
]
{batching_section}"""

    with open(os.path.join(repo_dir, "config.pbtxt"), "w") as f:
        f.write(config)

def get_model_io_names(onnx_model_path):
    """Return ``(input_name, output_name)`` for an ONNX model file.

    Only the first entry of ``graph.input`` and of ``graph.output`` is used;
    any further inputs or outputs are ignored.

    Args:
        onnx_model_path: Path to the ``.onnx`` file to inspect.
    """
    graph = onnx.load(onnx_model_path).graph
    first_input, first_output = graph.input[0], graph.output[0]
    return first_input.name, first_output.name


In [10]:
if __name__ == "__main__":
    # Stage 1: train the network and log the run to MLflow.
    trained_model = train_model()
    # Stage 2: export to model.onnx, relax its batch dim into
    # model_dynamic.onnx, then move that file into the Triton repository
    # alongside a generated config.pbtxt.
    convert_to_onnx(model=trained_model)
    enable_dynamic_batching()
    prepare_triton_repository()
    print("Model training, logging, and Triton deployment preparation complete.")

2024/05/28 12:39:59 INFO mlflow.tracking.fluent: Experiment with name 'MNIST_ONNX_Experiment' does not exist. Creating a new experiment.


Training complete.
Model training, logging, and Triton deployment preparation complete.
