# SageMaker V3 Train-to-Inference E2E with MLflow Integration

This notebook demonstrates the complete end-to-end workflow from training a custom PyTorch model to deploying it for inference on SageMaker cloud infrastructure, with MLflow 3.x tracking and model registry integration.

### Prerequisites
- SageMaker MLflow App created (tracking server ARN required)
- IAM permissions for MLflow tracking and model registry
- AWS credentials configured

## Step 0: Install Dependencies

**Note:** MLflow model path resolution has known issues. The cell below installs the SDK packages from a local checkout of the GitHub repository so that the latest fixes are included.

In [None]:
# Install from local SDK for development (includes fixes for MLflow path resolution issues)
%pip install -e ../../sagemaker-core -e ../../sagemaker-train -e ../../sagemaker-serve -e ../../sagemaker-mlops -e ../../. "mlflow==3.4.0" --upgrade

#### NOTE: Restart the kernel after installation so the newly installed packages are loaded
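After the restart, it can help to confirm which package versions the kernel actually sees. The sketch below uses only the standard library; the `EXPECTED` mapping is an illustrative assumption that mirrors the `mlflow==3.4.0` pin from the install cell, and the helper names are hypothetical.

```python
from importlib import metadata

# Pin copied from the install cell; extend this mapping as needed (assumption).
EXPECTED = {"mlflow": "3.4.0"}


def installed_version(package: str):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report_versions(expected: dict) -> dict:
    """Map each package name to (installed, expected, matches)."""
    report = {}
    for package, wanted in expected.items():
        found = installed_version(package)
        report[package] = (found, wanted, found == wanted)
    return report


for package, (found, wanted, ok) in report_versions(EXPECTED).items():
    status = "OK" if ok else "MISMATCH"
    print(f"{package}: installed={found} expected={wanted} [{status}]")
```

A `MISMATCH` line usually means the kernel was not restarted or a different environment is active.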

## Step 1: Configuration

Set up MLflow tracking server and training configuration.
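Because the configuration cell ships with a placeholder ARN, a fail-fast check before any AWS call can save a wasted training run. The following is a minimal sketch: the ARN pattern is a loose shape check modelled on the example ARN in the cell below (the accepted resource types are an assumption), and the name check reflects the usual SageMaker limit of 63 alphanumeric-or-hyphen characters. Both helper names are hypothetical.

```python
import re
import uuid

# Loose shape check only; the accepted resource types are an assumption
# based on the example ARN shown in the configuration cell.
ARN_PATTERN = re.compile(
    r"arn:aws[a-z-]*:sagemaker:[a-z0-9-]+:\d+:mlflow-(app|tracking-server)/\S+"
)
# Alphanumerics separated by hyphens; length is checked separately.
NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9])*")
MAX_NAME_LENGTH = 63


def looks_like_tracking_arn(value: str) -> bool:
    """Return True if the value has the general shape of an MLflow ARN."""
    return ARN_PATTERN.fullmatch(value) is not None


def make_resource_name(prefix: str) -> str:
    """Append an 8-character unique suffix and validate the resulting name."""
    name = f"{prefix}-{str(uuid.uuid4())[:8]}"
    if len(name) > MAX_NAME_LENGTH or NAME_PATTERN.fullmatch(name) is None:
        raise ValueError(f"invalid SageMaker resource name: {name!r}")
    return name


print(looks_like_tracking_arn("XXXXX"))  # False: the placeholder was not replaced
print(make_resource_name("mlflow-e2e-endpoint"))
```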

In [None]:
import uuid
from sagemaker.core import image_uris

# =============================================================================
# MLflow Configuration - UPDATE THIS WITH YOUR TRACKING SERVER ARN
# =============================================================================
# E.g., "arn:aws:sagemaker:us-east-1:12345678:mlflow-app/app-ABCDEFGH123"
MLFLOW_TRACKING_ARN = "XXXXX"

# AWS Configuration
AWS_REGION = "us-east-1"

# Get PyTorch training image dynamically
PYTORCH_TRAINING_IMAGE = image_uris.retrieve(
    framework="pytorch",
    region=AWS_REGION,
    version="2.5",
    py_version="py311",
    instance_type="ml.m5.xlarge",
    image_scope="training"
)
print(f"Using PyTorch training image: {PYTORCH_TRAINING_IMAGE}")

# Naming prefixes
MODEL_NAME_PREFIX = "mlflow-e2e-model"
ENDPOINT_NAME_PREFIX = "mlflow-e2e-endpoint"
TRAINING_JOB_PREFIX = "mlflow-e2e-pytorch"
MLFLOW_EXPERIMENT_NAME = "sagemaker-v3-e2e-training"
MLFLOW_REGISTERED_MODEL_NAME = "pytorch-simple-classifier"

# Generate unique identifiers
unique_id = str(uuid.uuid4())[:8]
training_job_name = f"{TRAINING_JOB_PREFIX}-{unique_id}"
model_name = f"{MODEL_NAME_PREFIX}-{unique_id}"
endpoint_name = f"{ENDPOINT_NAME_PREFIX}-{unique_id}"

print(f"Training job name: {training_job_name}")
print(f"Model name: {model_name}")
print(f"Endpoint name: {endpoint_name}")

## Step 2: Connect to MLflow Tracking Server

In [None]:
import mlflow

# Connect to SageMaker MLflow tracking server
mlflow.set_tracking_uri(MLFLOW_TRACKING_ARN)

# Create or get experiment
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

print("Connected to MLflow tracking server")
print(f"Experiment: {MLFLOW_EXPERIMENT_NAME}")

## Step 3: Create Training Code with MLflow Logging

Create a PyTorch training script that logs metrics and registers the model to MLflow.
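The script in the next cell is generated from an f-string, so two kinds of braces appear in it: single braces are filled in when the script is generated, while doubled braces survive as literal braces in the generated file. The toy template below illustrates the mechanism. The function name and template body are hypothetical, and the `!r` conversion is a suggested hardening (not what the notebook does) that keeps values containing quotes from breaking the generated code.

```python
import ast


def render_training_script(experiment_name: str, epochs: int) -> str:
    """Render a tiny stand-in for train.py from an f-string template."""
    return f'''
experiment = {experiment_name!r}
params = {{"epochs": {epochs}}}
for epoch in range(params["epochs"]):
    print(f"[{{experiment}}] epoch {{epoch + 1}}/{{params['epochs']}}")
'''


script = render_training_script('demo "quoted" experiment', epochs=2)

# ast.parse raises SyntaxError if the generated code is malformed
ast.parse(script)

# Execute the generated script in an isolated namespace to inspect the result
namespace = {}
exec(script, namespace)
print(namespace["params"])
```

Parsing the generated text before submitting a job catches template mistakes locally rather than minutes into a training run.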

In [None]:
import tempfile
import os

def create_pytorch_training_code_with_mlflow(mlflow_tracking_arn, experiment_name, registered_model_name):
    """Create PyTorch training script with MLflow integration."""
    temp_dir = tempfile.mkdtemp()
    
    train_script = f'''import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature

class SimpleModel(nn.Module):
    def __init__(self, input_dim=4, output_dim=2):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        return torch.softmax(self.linear(x), dim=1)

def train():
    # MLflow setup
    mlflow.set_tracking_uri("{mlflow_tracking_arn}")
    mlflow.set_experiment("{experiment_name}")
    
    # Hyperparameters
    learning_rate = 0.01
    epochs = 10
    batch_size = 32
    input_dim = 4
    output_dim = 2
    
    with mlflow.start_run() as run:
        # Log hyperparameters
        mlflow.log_params({{
            "learning_rate": learning_rate,
            "epochs": epochs,
            "batch_size": batch_size,
            "input_dim": input_dim,
            "output_dim": output_dim,
            "optimizer": "Adam",
            "loss_function": "CrossEntropyLoss"
        }})
        
        model = SimpleModel(input_dim, output_dim)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
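        # Note: CrossEntropyLoss expects raw logits, but SimpleModel.forward()
        # already applies softmax, so the loss here is computed on probabilities.
        criterion = nn.CrossEntropyLoss()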
        
        # Synthetic data
        X = torch.randn(100, input_dim)
        y = torch.randint(0, output_dim, (100,))
        dataset = TensorDataset(X, y)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        
        # Training loop with metric logging
        model.train()
        for epoch in range(epochs):
            epoch_loss = 0.0
            correct = 0
            total = 0
            
            for batch_x, batch_y in dataloader:
                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += batch_y.size(0)
                correct += (predicted == batch_y).sum().item()
            
            avg_loss = epoch_loss / len(dataloader)
            accuracy = correct / total
            
            # Log metrics per epoch
            mlflow.log_metrics({{
                "train_loss": avg_loss,
                "train_accuracy": accuracy
            }}, step=epoch)
            
            print(f"Epoch {{epoch+1}}/{{epochs}} - Loss: {{avg_loss:.4f}}, Accuracy: {{accuracy:.4f}}")
        
        # Log final metrics
        mlflow.log_metrics({{
            "final_loss": avg_loss,
            "final_accuracy": accuracy
        }})
        
        # Infer signature and register model to MLflow
        model.eval()
        signature = infer_signature(
            X.numpy(),
            model(X).detach().numpy()
        )
        
        # Log and register model in one step
        mlflow.pytorch.log_model(
            model,
            name="{registered_model_name}",
            signature=signature,
            registered_model_name="{registered_model_name}"
        )
        
        print(f"Model registered to MLflow: {registered_model_name}")
        print(f"Run ID: {{run.info.run_id}}")
        
        print("Training completed!")

if __name__ == "__main__":
    train()
'''
    
    with open(os.path.join(temp_dir, 'train.py'), 'w') as f:
        f.write(train_script)
    
    with open(os.path.join(temp_dir, 'requirements.txt'), 'w') as f:
        f.write('mlflow==3.4.0\nsagemaker-mlflow==0.2.0\ncloudpickle==3.1.2\n')
    
    return temp_dir

# Create training code
training_code_dir = create_pytorch_training_code_with_mlflow(
    MLFLOW_TRACKING_ARN, 
    MLFLOW_EXPERIMENT_NAME,
    MLFLOW_REGISTERED_MODEL_NAME
)
print(f"Training code created in: {training_code_dir}")

## Step 4: Create ModelTrainer and Start Training

Use ModelTrainer to run the training script on SageMaker managed infrastructure. The training job will log metrics to MLflow and register the model to the MLflow registry.
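Before submitting the job, a local pre-flight check of the source directory can catch a missing file or a syntax error without waiting for the training container to start. This is a minimal sketch using only the standard library; `check_source_dir` is a hypothetical helper, not part of the SageMaker SDK, and the demo directory merely mirrors the layout created in Step 3.

```python
import ast
import os
import shutil
import tempfile


def check_source_dir(source_dir, entry_script="train.py", requirements="requirements.txt"):
    """Return a list of problems; an empty list means the directory looks ready."""
    problems = []
    entry_path = os.path.join(source_dir, entry_script)
    if not os.path.isfile(entry_path):
        problems.append(f"missing entry script: {entry_script}")
    else:
        with open(entry_path, encoding="utf-8") as f:
            try:
                ast.parse(f.read(), filename=entry_script)
            except SyntaxError as exc:
                problems.append(f"syntax error in {entry_script} at line {exc.lineno}")
    if not os.path.isfile(os.path.join(source_dir, requirements)):
        problems.append(f"missing requirements file: {requirements}")
    return problems


# Demo on a throwaway directory with the same two files as Step 3
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "train.py"), "w", encoding="utf-8") as f:
    f.write("print('training')\n")
with open(os.path.join(demo_dir, "requirements.txt"), "w", encoding="utf-8") as f:
    f.write("mlflow==3.4.0\n")
demo_problems = check_source_dir(demo_dir)
print(demo_problems)
shutil.rmtree(demo_dir)
```

In the notebook, the same helper could be called as `check_source_dir(training_code_dir)` before `model_trainer.train()`.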

In [None]:
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode

# Training on SageMaker managed infrastructure
model_trainer = ModelTrainer(
    training_image=PYTORCH_TRAINING_IMAGE,
    source_code=SourceCode(
        source_dir=training_code_dir,
        entry_script="train.py",
        requirements="requirements.txt",
    ),
    base_job_name=training_job_name,
)

# Start training job
print(f"Starting training job: {training_job_name}")
print("Metrics will be logged to MLflow during training...")

model_trainer.train() 
print("Training completed! Check MLflow UI for metrics and registered model.")

## Step 5: Get Registered Model from MLflow

Retrieve the registered model from MLflow to get the model URI (`models:/<name>/<version>`) needed for deployment with ModelBuilder.
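Version numbers come back from the registry as strings, and `latest_versions` holds the latest version per stage, so when several entries are returned, index 0 may not be the highest version. The sketch below picks the numerically newest version and builds the `models:/<name>/<version>` URI. `VersionStub` is a hypothetical stand-in that exposes only the `version` attribute; it is not an MLflow class.

```python
from dataclasses import dataclass


@dataclass
class VersionStub:
    """Hypothetical stand-in exposing only the `version` attribute used here."""
    version: str


def newest_version(versions):
    """Pick the entry with the highest version, comparing numerically."""
    return max(versions, key=lambda v: int(v.version))


def registry_uri(model_name: str, version) -> str:
    """Build a model registry URI of the form models:/<name>/<version>."""
    return f"models:/{model_name}/{version}"


# A string comparison would rank "9" above "10"; the numeric key avoids that
candidates = [VersionStub("2"), VersionStub("10"), VersionStub("9")]
uri = registry_uri("pytorch-simple-classifier", newest_version(candidates).version)
print(uri)  # models:/pytorch-simple-classifier/10
```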

In [None]:
# Get the latest version of the registered model
from mlflow import MlflowClient

client = MlflowClient()
registered_model = client.get_registered_model(name=MLFLOW_REGISTERED_MODEL_NAME)

latest_version = registered_model.latest_versions[0]
model_version = latest_version.version
model_source = latest_version.source

# Get S3 URL of model files (for info only)
artifact_uri = client.get_model_version_download_uri(MLFLOW_REGISTERED_MODEL_NAME, model_version)

# MLflow model registry path to use with ModelBuilder
mlflow_model_path = f"models:/{MLFLOW_REGISTERED_MODEL_NAME}/{model_version}"

print(f"Registered Model: {MLFLOW_REGISTERED_MODEL_NAME}")
print(f"Latest Version: {model_version}")
print(f"Source: {model_source}")
print(f"Model artifacts location: {artifact_uri}")

## Step 6: Deploy from MLflow Model Registry

Use ModelBuilder to deploy the model directly from MLflow registry to a SageMaker endpoint.
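The translators defined in the next cell convert between JSON bytes on the wire and tensors inside the container. That round trip can be exercised locally without an endpoint. The sketch below is a simplified stand-in that uses nested lists instead of `torch.Tensor` so it runs with the standard library alone; the function names are hypothetical.

```python
import io
import json


def serialize_payload(payload) -> bytes:
    """Encode a payload as UTF-8 JSON bytes."""
    # Tensor-like objects expose .tolist(); plain lists pass through unchanged
    if hasattr(payload, "tolist"):
        payload = payload.tolist()
    return json.dumps(payload).encode("utf-8")


def deserialize_payload(stream):
    """Decode JSON from a binary stream, as the container side would."""
    return json.load(stream)


sample_input = [[0.1, 0.2, 0.3, 0.4]]
wire_bytes = serialize_payload(sample_input)
restored = deserialize_payload(io.BytesIO(wire_bytes))

print(wire_bytes)
print(restored == sample_input)
```

In the real input translator, the deserialized list is additionally wrapped in `torch.tensor(..., dtype=torch.float32)` before it reaches the model.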

In [None]:
import json
import torch
from sagemaker.serve.marshalling.custom_payload_translator import CustomPayloadTranslator
from sagemaker.serve.builder.schema_builder import SchemaBuilder

# =============================================================================
# Custom translators for PyTorch tensor conversion
# 
# PyTorch models expect tensors, but SageMaker endpoints communicate via JSON.
# These translators handle the conversion between JSON payloads and PyTorch tensors.
# =============================================================================

class PyTorchInputTranslator(CustomPayloadTranslator):
    """Handles input serialization/deserialization for PyTorch models."""
    def __init__(self):
        super().__init__(content_type='application/json', accept_type='application/json')
    
    def serialize_payload_to_bytes(self, payload: object) -> bytes:
        if isinstance(payload, torch.Tensor):
            return json.dumps(payload.tolist()).encode('utf-8')
        return json.dumps(payload).encode('utf-8')
    
    def deserialize_payload_from_stream(self, stream) -> object:
        data = json.load(stream)
        return torch.tensor(data, dtype=torch.float32)

class PyTorchOutputTranslator(CustomPayloadTranslator):
    """Handles output serialization/deserialization for PyTorch models."""
    def __init__(self):
        super().__init__(content_type='application/json', accept_type='application/json')
    
    def serialize_payload_to_bytes(self, payload: object) -> bytes:
        if isinstance(payload, torch.Tensor):
            return json.dumps(payload.tolist()).encode('utf-8')
        return json.dumps(payload).encode('utf-8')
    
    def deserialize_payload_from_stream(self, stream) -> object:
        return json.load(stream)

# Sample input/output for schema inference
sample_input = [[0.1, 0.2, 0.3, 0.4]]
sample_output = [[0.8, 0.2]]

schema_builder = SchemaBuilder(
    sample_input=sample_input,
    sample_output=sample_output,
    input_translator=PyTorchInputTranslator(),
    output_translator=PyTorchOutputTranslator()
)

In [None]:
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.mode.function_pointers import Mode

# Cloud deployment to SageMaker endpoint
model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=schema_builder,
    model_metadata={
        "MLFLOW_MODEL_PATH": mlflow_model_path,
        "MLFLOW_TRACKING_ARN": MLFLOW_TRACKING_ARN
    },
    dependencies={"auto": False, "custom": ["mlflow==3.4.0", "sagemaker==3.3.1", "numpy==2.4.1", "cloudpickle==3.1.2"]},
)

print(f"ModelBuilder configured with MLflow model: {mlflow_model_path}")

In [None]:
# Build the model
core_model = model_builder.build(model_name=model_name, region=AWS_REGION)
print(f"Model built: {core_model.model_name}")

In [None]:
# Deploy to SageMaker endpoint
core_endpoint = model_builder.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1
)

print(f"Endpoint deployed: {core_endpoint.endpoint_name}")

## Step 7: Test the Deployed Model

Invoke the endpoint with a sample input. The model returns class probabilities (2 classes) as a softmax output.
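Since each output row is a softmax distribution, its entries should sum to 1 and the predicted class is the index of the largest entry. The sketch below shows both the softmax computation and the interpretation step in plain Python. It assumes the response decodes to a list of rows such as `[[0.8, 0.2]]`, which matches the sample output used for schema inference but is not guaranteed for every deployment.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_classes(probabilities, tolerance=1e-3):
    """Return the argmax of each row after checking that it sums to 1."""
    classes = []
    for row in probabilities:
        if not math.isclose(sum(row), 1.0, abs_tol=tolerance):
            raise ValueError(f"row does not sum to 1: {row}")
        classes.append(max(range(len(row)), key=row.__getitem__))
    return classes


print(softmax([0.0, 0.0]))              # [0.5, 0.5]
print(predicted_classes([[0.8, 0.2]]))  # [0]
```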

In [None]:
import boto3

# Test with JSON input
test_data = [[0.1, 0.2, 0.3, 0.4]]

# Use the same region the endpoint was deployed to
runtime_client = boto3.client('sagemaker-runtime', region_name=AWS_REGION)
response = runtime_client.invoke_endpoint(
    EndpointName=core_endpoint.endpoint_name,
    Body=json.dumps(test_data),
    ContentType='application/json'
)

prediction = json.loads(response['Body'].read().decode('utf-8'))
print(f"Input: {test_data}")
print(f"Prediction: {prediction}")

## Step 8: Clean Up Resources

In [None]:
import shutil
from sagemaker.core.resources import EndpointConfig

# Clean up AWS resources
# Look up the endpoint config by the endpoint's name (assumes both share the same name)
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)
core_model.delete()
core_endpoint.delete()
core_endpoint_config.delete()
print("AWS resources cleaned up!")

# Clean up training code directory
try:
    shutil.rmtree(training_code_dir)
    print("Cleaned up training code directory")
except Exception as e:
    print(f"Could not clean up training code: {e}")

print("Note: MLflow experiment runs and registered models are preserved.")

## Summary

This notebook demonstrates cloud deployment of a PyTorch model with MLflow integration:

1. **Training**: Runs on SageMaker managed infrastructure with ModelTrainer
2. **MLflow Integration**: Logs metrics and parameters, and registers the model in the MLflow registry
3. **Deployment**: Uses ModelBuilder to deploy directly from MLflow registry to a SageMaker endpoint
4. **Inference**: Invokes the endpoint with JSON payloads

Key MLflow integration points:
- `mlflow.log_params()` - hyperparameters
- `mlflow.log_metrics()` - training metrics per epoch
- `mlflow.pytorch.log_model()` - model artifact with registry
- `ModelBuilder` with `MLFLOW_MODEL_PATH` - deploy from registry

Key patterns:
- `CustomPayloadTranslator` subclasses for PyTorch tensor serialization
