# SageMaker V3 Train-to-Inference E2E Example

This notebook demonstrates the complete end-to-end workflow from training a custom PyTorch model to deploying it for inference using SageMaker V3.

In [None]:
# Import required libraries
import json
import uuid
import tempfile
import os
import boto3

from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import SourceCode
from sagemaker.core.resources import EndpointConfig
from sagemaker.core.helper.session_helper import Session

## Step 1: Configure Training Job

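Define the configuration used throughout the notebook: resource-name prefixes, the AWS region, the PyTorch training image, and unique names for the training job, model, and endpoint. The simple neural network itself is created in Step 2.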

In [None]:
# Configuration for training
# Prefixes for the resources this notebook creates; a short random suffix is
# appended below so repeated runs do not collide on names.
MODEL_NAME_PREFIX = "train-inf-v3-example-model"
ENDPOINT_NAME_PREFIX = "train-inf-v3-example-endpoint"
TRAINING_JOB_PREFIX = "e2e-v3-pytorch"

# AWS Configuration
AWS_REGION = "us-west-2"
# The registry host is derived from AWS_REGION so that the SageMaker session
# and the container image always point at the same region when the constant
# above is changed.
# NOTE(review): the 763104351884 registry account is assumed to serve this
# image in the chosen region — confirm before switching regions.
PYTORCH_TRAINING_IMAGE = (
    f"763104351884.dkr.ecr.{AWS_REGION}.amazonaws.com/"
    "pytorch-training:1.13.1-cpu-py39"
)

# Generate unique identifiers
# The first 8 characters of a UUID4 string are hexadecimal digits (the first
# hyphen appears at index 8), so the suffix is safe to embed in resource names.
unique_id = str(uuid.uuid4())[:8]
training_job_name = f"{TRAINING_JOB_PREFIX}-{unique_id}"
model_name = f"{MODEL_NAME_PREFIX}-{unique_id}"
endpoint_name = f"{ENDPOINT_NAME_PREFIX}-{unique_id}"

print(f"Training job name: {training_job_name}")
print(f"Model name: {model_name}")
print(f"Endpoint name: {endpoint_name}")

## Step 2: Create Training Code

Create a simple PyTorch training script and requirements file.

In [None]:
def create_pytorch_training_code():
    """Write a minimal PyTorch training project to a fresh temporary directory.

    Two files are created:

    * ``train.py`` - trains a one-layer classifier on synthetic data and saves
      a TorchScript-traced copy as ``model.pth`` under ``SM_MODEL_DIR``
      (falling back to ``/opt/ml/model``).
    * ``requirements.txt`` - pins the torch version range.

    Returns:
        str: Path of the temporary directory holding both files. The directory
        is created with ``tempfile.mkdtemp`` and is not removed by this
        function; the caller owns its lifetime.
    """
    temp_dir = tempfile.mkdtemp()

    # The script is kept as a plain string because it runs inside the training
    # container, not in this notebook's kernel.
    train_script = '''import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def logits(self, x):
        # Raw, unnormalised class scores; this is what the loss must see.
        return self.linear(x)

    def forward(self, x):
        # Class probabilities; this is what the traced (served) model returns.
        return torch.softmax(self.logits(x), dim=1)

def train():
    model = SimpleModel()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    # CrossEntropyLoss applies log-softmax internally, so it is fed logits
    # rather than the already-normalised output of forward().
    criterion = nn.CrossEntropyLoss()

    # Synthetic data
    X = torch.randn(100, 4)
    y = torch.randint(0, 2, (100,))
    dataset = TensorDataset(X, y)
    dataloader = DataLoader(dataset, batch_size=32)

    # Train for 1 epoch
    model.train()
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        loss = criterion(model.logits(batch_x), batch_y)
        loss.backward()
        optimizer.step()

    # Save model for TorchServe
    model.eval()
    traced_model = torch.jit.trace(model, torch.randn(1, 4))

    model_dir = os.environ.get('SM_MODEL_DIR', '/opt/ml/model')
    os.makedirs(model_dir, exist_ok=True)
    torch.jit.save(traced_model, os.path.join(model_dir, 'model.pth'))

    print("Training completed and model saved!")

if __name__ == "__main__":
    train()
'''

    with open(os.path.join(temp_dir, 'train.py'), 'w', encoding='utf-8') as f:
        f.write(train_script)

    with open(os.path.join(temp_dir, 'requirements.txt'), 'w', encoding='utf-8') as f:
        f.write('torch>=1.13.0,<2.0.0\n')

    return temp_dir

# Create training code
# Writes train.py and requirements.txt into a new temporary directory and
# keeps the path so it can be passed as the training source directory later.
training_code_dir = create_pytorch_training_code()
print(f"Training code created in: {training_code_dir}")

## Step 3: Create ModelTrainer and Start Training

Set up the ModelTrainer with custom PyTorch code and launch the training job.

In [None]:
# Create SageMaker session
# A boto3 session pinned to the configured region, wrapped for the SageMaker SDK.
boto_session = boto3.Session(region_name=AWS_REGION)
sagemaker_session = Session(boto_session=boto_session)

# Describe the training source: the generated directory, its entry script,
# and the requirements file that sits next to it.
training_source = SourceCode(
    source_dir=training_code_dir,
    entry_script="train.py",
    requirements="requirements.txt",
)

# Create ModelTrainer with custom code
model_trainer = ModelTrainer(
    sagemaker_session=sagemaker_session,
    training_image=PYTORCH_TRAINING_IMAGE,
    source_code=training_source,
    base_job_name=training_job_name,
)

# Start training job
print(f"Starting training job: {training_job_name}")
print("Note: This will take a few minutes to complete.")

# NOTE(review): the completion message below assumes train() blocks until the
# job has finished — confirm against the ModelTrainer documentation.
model_trainer.train()
print("Model Training Completed!")

## Step 4: Create Schema Builder and Inference Spec

Set up the schema and inference specification for the trained model.

In [None]:
# Create schema builder for tensor-based models
def create_schema_builder():
    """Return a SchemaBuilder built from one example input/output pair.

    The example input is a single row of four floats and the example output a
    single row of two floats, mirroring the 4-in / 2-out linear model defined
    in the training script.
    """
    return SchemaBuilder(
        [[0.1, 0.2, 0.3, 0.4]],  # sample request: one row, four features
        [[0.8, 0.2]],            # sample response: one row, two scores
    )

# Create inference specification
class SimpleInferenceSpec(InferenceSpec):
    """Inference hooks for the TorchScript model saved by the training script.

    torch is imported inside each method rather than at notebook level; the
    notebook's import cell does not import torch at all.
    """

    def load(self, model_dir):
        """Load ``model.pth`` from ``model_dir`` and return it in eval mode.

        Args:
            model_dir: Directory that contains the traced ``model.pth`` file.

        Returns:
            The TorchScript module, mapped to CPU and switched to eval mode.
        """
        import torch

        # map_location keeps the load independent of the device the model was
        # saved from; the serving image used in this notebook is a CPU image.
        model = torch.jit.load(f"{model_dir}/model.pth", map_location="cpu")
        model.eval()
        return model

    def invoke(self, input_object, model):
        """Run the model on ``input_object`` and return nested Python lists.

        Args:
            input_object: Array-like batch of feature rows.
            model: The module returned by ``load``.

        Returns:
            list: The model output converted with ``Tensor.tolist()``.
        """
        import torch

        # Force float32: an integer-only payload would otherwise become an
        # int64 tensor, which the model's linear layer rejects.
        batch = torch.as_tensor(input_object, dtype=torch.float32)
        # Gradients are not needed for inference.
        with torch.no_grad():
            return model(batch).tolist()

# Build the schema once here; it is passed to ModelBuilder in the next step.
schema_builder = create_schema_builder()
print("Schema builder and inference spec created successfully!")

## Step 5: Create ModelBuilder and Build Model

Create the ModelBuilder with the trained model and build it for deployment.

In [None]:
# Create ModelBuilder with trained model
# The serving image is obtained from the training image URI by swapping the
# word "training" for "inference"; registry, tag, and region stay the same.
inference_image_uri = PYTORCH_TRAINING_IMAGE.replace("training", "inference")

model_builder = ModelBuilder(
    model=model_trainer,                  # the trainer object from Step 3
    schema_builder=schema_builder,
    model_server=ModelServer.TORCHSERVE,
    inference_spec=SimpleInferenceSpec(),
    image_uri=inference_image_uri,
    dependencies={"auto": False},
)

# Build the trained model
core_model = model_builder.build(model_name=model_name, region=AWS_REGION)
print(f"Model Successfully Created: {core_model.model_name}")

## Step 6: Deploy the Trained Model

Deploy the trained model to a SageMaker endpoint.

In [None]:
# Deploy the trained model
# Creates an endpoint with the unique name generated in Step 1, backed by a
# single instance. No instance type is passed here.
# NOTE(review): the instance type presumably falls back to an SDK default —
# confirm it suits the CPU inference image.
core_endpoint = model_builder.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1
)
print(f"Endpoint Successfully Created: {core_endpoint.endpoint_name}")

## Step 7: Test the Trained Model

Test the deployed trained model with sample tensor inputs.

In [None]:
# Test the trained model with tensor input
# One row of four features, matching the model's input width.
test_data = [[0.1, 0.2, 0.3, 0.4]]
request_body = json.dumps(test_data)

result = core_endpoint.invoke(body=request_body, content_type="application/json")

# Decode and display the result
# The response body is read as bytes, decoded as UTF-8, and parsed as JSON.
response_text = result.body.read().decode('utf-8')
prediction = json.loads(response_text)
print(f"Result of invoking endpoint: {prediction}")

In [None]:
# Test with different tensor inputs
# Each entry is a one-row batch of four features.
test_inputs = [
    [[0.5, 0.3, 0.2, 0.1]],
    [[0.9, 0.1, 0.8, 0.2]],
    [[0.2, 0.7, 0.4, 0.6]],
]
separator = "-" * 50

for case_number, payload in enumerate(test_inputs, start=1):
    # Send the batch as JSON and parse the JSON response.
    result = core_endpoint.invoke(
        body=json.dumps(payload),
        content_type="application/json",
    )
    response_text = result.body.read().decode('utf-8')
    prediction = json.loads(response_text)

    print(f"Test {case_number} - Input {payload}: {prediction}")
    print(separator)

## Step 8: Clean Up Resources

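Delete the inference resources created above (the endpoint, its endpoint configuration, and the model) so they stop incurring charges. Training job artifacts in S3 are not removed.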

In [None]:
# Clean up resources
# Fetch the endpoint configuration before anything is deleted.
# NOTE(review): this lookup assumes the endpoint configuration shares the
# endpoint's name — confirm against how deploy() names it.
core_endpoint_config = EndpointConfig.get(endpoint_config_name=core_endpoint.endpoint_name)

# Delete in dependency order: endpoint -> endpoint config -> model.
# The endpoint goes first because it is the resource backed by running
# instances; if a later delete fails, nothing billable is left behind.
core_endpoint.delete()
core_endpoint_config.delete()
core_model.delete()

print("Model and Endpoint Successfully Deleted!")
print("Note: Training job artifacts remain in S3 for reference.")

## Summary

This notebook demonstrated the complete E2E workflow:
1. Creating custom PyTorch training code
2. Configuring a ModelTrainer with custom source code
3. Running a training job on SageMaker
4. Building a ModelBuilder from training artifacts
5. Deploying the trained model to an endpoint
6. Testing the trained model with tensor inputs
7. Proper cleanup of inference resources

## Key Benefits of E2E Training:
- **Custom training**: Full control over PyTorch training process
- **Seamless workflow**: Train → Build → Deploy in one pipeline
- **Artifact management**: Automatic handling of training outputs
- **TorchServe integration**: Easy deployment with TorchServe
- **Production ready**: Trained models ready for immediate deployment

The V3 ModelBuilder makes it easy to go from custom training to production inference with minimal code!