# Setup SageMaker MLflow Tracking Server

## Overview

This notebook guides you through creating and configuring an Amazon SageMaker MLflow tracking server that will be used by other observability notebooks in this workshop.

**Amazon SageMaker MLflow** is a fully managed service that provides MLflow tracking capabilities with enterprise-grade security, scalability, and integration with AWS services. It allows you to:

- Track experiments, parameters, metrics, and artifacts
- Trace agent workflows and LLM interactions
- Compare runs and visualize results through the MLflow UI
- Store metadata securely in AWS with automatic backups

**Resources:**
- [SageMaker MLflow Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/mlflow.html)
- [MLflow IAM Permissions](https://docs.aws.amazon.com/sagemaker/latest/dg/mlflow-create-tracking-server-iam.html)

**Estimated completion time:** 30-40 minutes (most of this is spent waiting for the server to be provisioned)

## Prerequisites

Before starting this notebook, ensure you have:
- AWS account with appropriate permissions
- Access to Amazon SageMaker Studio
- IAM role with permissions to create SageMaker MLflow tracking servers
- An S3 bucket for storing MLflow artifacts (this notebook creates one if it does not already exist)

## What You'll Learn

By the end of this notebook, you will:
1. Understand SageMaker MLflow architecture and components
2. Create an S3 bucket for MLflow artifacts
3. Set up IAM roles and policies for MLflow
4. Create a SageMaker MLflow tracking server
5. Verify the tracking server is operational
6. Get the tracking server ARN for use in other notebooks

---

## Step 1: Install Required Packages

First, let's install the necessary Python packages.

In [None]:
# Install required packages
%pip install boto3 sagemaker mlflow sagemaker-mlflow --upgrade --quiet

<div style="background-color: #fff3cd; padding: 15px; border-left: 5px solid #ffc107; margin: 10px 0;">
    <strong>⚠️ Important:</strong> After running the cell above, please <strong>restart the kernel</strong> before continuing.
    <br><br>
    <strong>How to restart:</strong>
    <ul>
        <li>Click <strong>Kernel</strong> → <strong>Restart Kernel</strong> in the menu, OR</li>
        <li>Run the cell below to restart automatically</li>
    </ul>
</div>

In [None]:
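# Uncomment and run this cell to restart the kernel automatically.
# Note: this relies on the classic Jupyter Notebook JavaScript API and may have no
# effect in JupyterLab-based environments such as SageMaker Studio; use the menu there.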
# from IPython.core.display import HTML
# HTML("<script>Jupyter.notebook.kernel.restart()</script>")
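
Once the kernel has restarted, it can help to confirm that the upgraded packages are visible to the new kernel. The following is a minimal sketch that uses only the standard library; the helper name `installed_versions` is hypothetical and not part of the workshop code.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# The four distributions installed in Step 1
for name, version in installed_versions(
    ["boto3", "sagemaker", "mlflow", "sagemaker-mlflow"]
).items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

A `NOT INSTALLED` entry suggests that the install cell did not complete or that the kernel is using a different environment.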

## Step 2: Import Libraries and Initialize AWS Clients

**Note:** Run this cell after restarting the kernel.

In [None]:
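import time

import boto3
from botocore.exceptions import ClientError
from sagemaker.core.helper.session_helper import get_execution_role

# Shared session, region, and clients used by the cells below
session = boto3.Session()
region = session.region_name
sagemaker_client = session.client('sagemaker')
s3_client = session.client('s3')

# Execution role ARN (later cells refer to it as `role`)
role_arn = get_execution_role()
role = role_arn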

## Step 3: Configure MLflow Tracking Server Settings

Define the configuration for your MLflow tracking server. You can customize these values as needed.

**Important:** By default, this uses a fixed server name so you can reuse the same tracking server if you re-run the notebook.

In [None]:
# First, let's check if you already have any MLflow tracking servers
try:
    existing_servers = sagemaker_client.list_mlflow_tracking_servers()
    servers = existing_servers.get('TrackingServerSummaries', [])
    
    if servers:
        print("📋 Existing MLflow Tracking Servers:")
        print("="*80)
        for server in servers:
            print(f"\nName: {server['TrackingServerName']}")
            print(f"Status: {server['TrackingServerStatus']}")
            print(f"ARN: {server['TrackingServerArn']}")
            print(f"Created: {server.get('CreationTime', 'N/A')}")
        print("\n" + "="*80)
        print("\n💡 If you want to reuse an existing server, update the tracking_server_name below.")
    else:
        print("No existing MLflow tracking servers found.")
        print("A new one will be created in the next steps.")
except Exception as e:
    print(f"Could not list existing servers: {e}")

In [None]:
# Configuration
# Option 1: Use a fixed name (recommended - allows reusing the same server)
tracking_server_name = "mlflow-tracking-server-workshop"

# Option 2: Use a timestamp-based name (creates a new server each time)
# tracking_server_name = f"mlflow-tracking-server-{int(time.time())}"

artifact_bucket_name = f"sagemaker-mlflow-artifacts-{session.client('sts').get_caller_identity()['Account']}-{region}"

print(f"Tracking Server Name: {tracking_server_name}")
print(f"Artifact Bucket Name: {artifact_bucket_name}")
print("\n💡 Note: Using a fixed server name allows you to reuse the same tracking server.")
print("   If you want a new server each time, uncomment Option 2 above.")
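
Because the bucket name is assembled from the account ID and region, a quick sanity check before Step 4 can catch a malformed name early. The sketch below covers only a subset of the S3 naming rules (length, allowed characters, no consecutive dots); the helper name `is_plausible_bucket_name` and the example account ID are hypothetical.

```python
import re

# Subset of the S3 general-purpose bucket naming rules: 3-63 characters,
# lowercase letters, digits, dots and hyphens, starting and ending with a
# letter or digit. Consecutive dots are rejected separately below.
_BUCKET_NAME_RE = re.compile(r"[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]")


def is_plausible_bucket_name(name):
    """Return True if `name` passes the basic S3 naming checks above."""
    return _BUCKET_NAME_RE.fullmatch(name) is not None and ".." not in name


# Hypothetical account ID and region, for illustration only
example = "sagemaker-mlflow-artifacts-123456789012-us-west-2"
print(is_plausible_bucket_name(example))
```

Passing this check does not guarantee that the name is available, since bucket names are globally unique.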

## Step 4: Create S3 Bucket for MLflow Artifacts

MLflow needs an S3 bucket to store experiment artifacts, models, and other files.

In [None]:
def create_s3_bucket(bucket_name, region):
    """
    Create an S3 bucket for MLflow artifacts if it doesn't exist.
    """
    try:
        # Check if bucket already exists
        s3_client.head_bucket(Bucket=bucket_name)
        print(f"✓ Bucket '{bucket_name}' already exists.")
        return bucket_name
    except ClientError as e:
        error_code = e.response['Error']['Code']
        if error_code == '404':
            # Bucket doesn't exist, create it
            try:
                if region == 'us-east-1':
                    s3_client.create_bucket(Bucket=bucket_name)
                else:
                    s3_client.create_bucket(
                        Bucket=bucket_name,
                        CreateBucketConfiguration={'LocationConstraint': region}
                    )
                
                print(f"✓ Successfully created bucket '{bucket_name}'.")
                
                # Try to enable versioning (optional - not critical for MLflow)
                try:
                    s3_client.put_bucket_versioning(
                        Bucket=bucket_name,
                        VersioningConfiguration={'Status': 'Enabled'}
                    )
                    print(f"✓ Versioning enabled on bucket.")
                except ClientError as version_error:
                    print(f"⚠ Could not enable versioning (not critical): {version_error.response['Error']['Code']}")
                    print("  Bucket will work fine without versioning.")
                
                return bucket_name
            except ClientError as create_error:
                print(f"✗ Error creating bucket: {create_error}")
                raise
        else:
            print(f"✗ Error checking bucket: {e}")
            raise

# Create the S3 bucket
artifact_bucket = create_s3_bucket(artifact_bucket_name, region)
artifact_store_uri = f"s3://{artifact_bucket}/mlflow"
print(f"\nArtifact Store URI: {artifact_store_uri}")

## Step 5: Update IAM Role Policy for MLflow Access

Your SageMaker execution role needs permissions to access MLflow tracking servers and the S3 artifact bucket.

**Run the cell below.** It prints the required policy with manual console instructions, then offers to attach it automatically as an inline policy.

In [None]:
import json

# Extract role name from ARN
role_name = role.split('/')[-1]

# Define the required policy
policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": "sagemaker-mlflow:*",
            "Resource": "*"
        },
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:DeleteObject",
                "s3:ListBucket",
                "s3:GetBucketVersioning",
                "s3:PutBucketVersioning"
            ],
            "Resource": [
                f"arn:aws:s3:::{artifact_bucket_name}",
                f"arn:aws:s3:::{artifact_bucket_name}/*"
            ]
        }
    ]
}

print("="*80)
print("IAM ROLE POLICY UPDATE REQUIRED")
print("="*80)
print(f"\n📋 Your IAM Role ARN:")
print(f"   {role}")
print(f"\n📋 Role Name:")
print(f"   {role_name}")
print(f"\n📋 S3 Bucket:")
print(f"   {artifact_bucket_name}")
print("\n" + "="*80)
print("INSTRUCTIONS: Add this policy to your IAM role")
print("="*80)
print("\n1. Open AWS Console → IAM → Roles")
print(f"2. Search for and click on role: {role_name}")
print("3. Click 'Add permissions' → 'Create inline policy'")
print("4. Click 'JSON' tab")
print("5. Copy and paste the policy below:")
print("\n" + "-"*80)
print(json.dumps(policy_document, indent=2))
print("-"*80)
print("\n6. Click 'Review policy'")
print("7. Name it: SageMakerMLflowAccess")
print("8. Click 'Create policy'")

print("\n" + "="*80)
print("OPTION 2: Automatic Policy Addition")
print("="*80)
auto_add = input("\nWould you like to add this policy automatically? (yes/no): ").strip().lower()

if auto_add == 'yes':
    try:
        iam_client = session.client('iam')
        iam_client.put_role_policy(
            RoleName=role_name,
            PolicyName='SageMakerMLflowAccess',
            PolicyDocument=json.dumps(policy_document)
        )
        print(f"✓ Successfully added inline policy 'SageMakerMLflowAccess' to role '{role_name}'")
    except ClientError as e:
        print(f"✗ Could not add policy automatically: {e}")
        print("Please follow the manual instructions above.")
else:
    print("\nPlease follow the manual instructions above, then continue to Step 6.")
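
The policy above grants `sagemaker-mlflow:*` on every resource, which is convenient for a workshop. Outside a workshop you may prefer to limit it to a single tracking server. The sketch below builds such a statement; the helper name `scoped_mlflow_statement` is hypothetical, the account ID is a placeholder, and the ARN layout is an assumption that should be compared with the ARN printed in Step 6 before relying on it.

```python
import json


def scoped_mlflow_statement(region, account_id, server_name):
    """Build an IAM statement limited to one tracking server (assumed ARN layout)."""
    server_arn = (
        f"arn:aws:sagemaker:{region}:{account_id}:"
        f"mlflow-tracking-server/{server_name}"
    )
    return {
        "Effect": "Allow",
        "Action": "sagemaker-mlflow:*",
        "Resource": server_arn,
    }


# Placeholder values for illustration only
statement = scoped_mlflow_statement(
    "us-west-2", "123456789012", "mlflow-tracking-server-workshop"
)
print(json.dumps(statement, indent=2))
```

The returned dictionary could replace the first entry of `policy_document["Statement"]` once the ARN has been confirmed.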

<div style="background-color: #d1ecf1; padding: 15px; border-left: 5px solid #0c5460; margin: 10px 0;">
    <strong>✋ Wait!</strong> Before proceeding to the next cell:
    <ol>
        <li>Complete the IAM policy update steps above</li>
        <li>Verify the policy was added successfully in the IAM console</li>
        <li>Then continue to Step 6 below</li>
    </ol>
</div>

## Step 6: Create SageMaker MLflow Tracking Server

Now we'll create the MLflow tracking server. The creation call returns immediately, but provisioning takes 20-25 minutes.

In [None]:
def create_mlflow_tracking_server(server_name, artifact_store_uri, role_arn):
    """
    Create a SageMaker MLflow tracking server, or return existing one if already created.
    """
    # First, check if a server with this name already exists
    try:
        existing_server = sagemaker_client.describe_mlflow_tracking_server(
            TrackingServerName=server_name
        )
        print(f"✓ Tracking server '{server_name}' already exists.")
        print(f"  Status: {existing_server['TrackingServerStatus']}")
        print(f"  ARN: {existing_server['TrackingServerArn']}")
        print("\nSkipping creation. Using existing tracking server.")
        return existing_server['TrackingServerArn']
    except ClientError as e:
        if e.response['Error']['Code'] != 'ResourceNotFound':
            # Some other error occurred
            print(f"Error checking for existing server: {e}")
            raise
        # Server doesn't exist, proceed with creation
    
    # Create new tracking server
    try:
        print(f"Creating new MLflow tracking server: {server_name}")
        response = sagemaker_client.create_mlflow_tracking_server(
            TrackingServerName=server_name,
            ArtifactStoreUri=artifact_store_uri,
            RoleArn=role_arn,
            TrackingServerSize='Small',  # Options: Small, Medium, Large
            AutomaticModelRegistration=False
        )
        
        tracking_server_arn = response['TrackingServerArn']
        print(f"✓ Tracking server creation initiated.")
        print(f"  ARN: {tracking_server_arn}")
        print(f"\n⏰ The server is now being provisioned (takes 20-25 minutes).")
        print(f"Proceed to Step 7 to check the status.")
        
        return tracking_server_arn
    except ClientError as e:
        print(f"✗ Error creating tracking server: {e}")
        raise

# Create the tracking server (or get existing one)
tracking_server_arn = create_mlflow_tracking_server(
    tracking_server_name,
    artifact_store_uri,
    role
)

## Step 7: Check Tracking Server Status

The tracking server takes 20-25 minutes to provision. Let's check its current status.

**Note:** If the status is not 'Created', wait a few minutes and re-run this cell until it shows 'Created'.

In [None]:
try:
    response = sagemaker_client.describe_mlflow_tracking_server(
        TrackingServerName=tracking_server_name
    )
    
    status = response['TrackingServerStatus']
    
    print("="*80)
    print("TRACKING SERVER STATUS CHECK")
    print("="*80)
    print(f"\nServer Name: {tracking_server_name}")
    print(f"Current Status: {status}")
    print(f"Creation Time: {response.get('CreationTime', 'N/A')}")
    
    if status == 'Created':
        print("\n" + "="*80)
        print("✓ SUCCESS: Tracking server is ready!")
        print("="*80)
        print("\nYou can now proceed to Step 8.")
        server_info = response
    elif status == 'Creating':
        print("\n" + "="*80)
        print("⏳ WAIT: Server is still being created...")
        print("="*80)
        print("\n⏰ This typically takes 20-25 minutes.")
        print("Please wait 3-5 minutes and re-run this cell.")
        print("\n💡 Tip: You can also check status in SageMaker Studio:")
        print("   SageMaker Studio → MLflow (left sidebar) → Tracking servers")
        server_info = None
    elif status in ['CreateFailed', 'DeleteFailed']:
        print("\n" + "="*80)
        print(f"✗ ERROR: Tracking server is in a failed state ({status})!")
        print("="*80)
        if 'FailureReason' in response:
            print(f"\nFailure Reason: {response['FailureReason']}")
        print("\nPlease check the error and try creating a new tracking server.")
        server_info = None
    else:
        print(f"\n⚠️  Unexpected status: {status}")
        print("Please wait and re-run this cell.")
        server_info = None
    
    print("\n" + "="*80)
    
except ClientError as e:
    print(f"✗ Error checking server status: {e}")
    server_info = None
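
As an alternative to re-running the status cell by hand, the check can be wrapped in a polling loop. The sketch below is one possible shape, not part of the workshop code: `wait_for_status` is a hypothetical helper that takes any zero-argument callable returning a dictionary with a `TrackingServerStatus` key, so it can be demonstrated here with a fake instead of a real AWS call.

```python
import time


def wait_for_status(describe, target="Created", failed=("CreateFailed",),
                    interval=60, timeout=2400, sleep=time.sleep):
    """Poll `describe()` until its status equals `target`.

    Raises RuntimeError on a failed status and TimeoutError once `timeout`
    seconds of waiting have been used up.
    """
    waited = 0
    while True:
        response = describe()
        status = response["TrackingServerStatus"]
        if status == target:
            return response
        if status in failed:
            raise RuntimeError(f"Tracking server entered status {status}")
        if waited >= timeout:
            raise TimeoutError(f"Still {status} after {waited} seconds")
        sleep(interval)
        waited += interval


# Demonstration with a fake describe function and no real sleeping
fake_statuses = iter(["Creating", "Creating", "Created"])
result = wait_for_status(
    lambda: {"TrackingServerStatus": next(fake_statuses)},
    sleep=lambda seconds: None,
)
print(result["TrackingServerStatus"])
```

In this notebook, `describe` could be `lambda: sagemaker_client.describe_mlflow_tracking_server(TrackingServerName=tracking_server_name)`, and the returned response could be assigned to `server_info` for Step 8.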

## Step 8: Verify Tracking Server and Get Details

Let's verify the tracking server is operational and get its details.

In [None]:
try:
    server_info
except NameError:
    print("⚠️  Please run Step 7 first to check the tracking server status.")
    server_info = None

if server_info:
    print("\n" + "="*80)
    print("MLflow Tracking Server Details")
    print("="*80)
    print(f"\nTracking Server Name: {server_info['TrackingServerName']}")
    print(f"Tracking Server ARN: {server_info['TrackingServerArn']}")
    print(f"Tracking Server URL: {server_info.get('TrackingServerUrl', 'N/A')}")
    print(f"Artifact Store URI: {server_info['ArtifactStoreUri']}")
    print(f"Status: {server_info['TrackingServerStatus']}")
    print(f"\n" + "="*80)
    print("\n⚠️  IMPORTANT: Save the Tracking Server ARN below for use in other notebooks!")
    print("="*80)
    print(f"\n{server_info['TrackingServerArn']}\n")
    print("="*80)
else:
    print("\n✗ Tracking server is not ready or could not be verified. Re-run Step 7 until the status is 'Created'.")

if server_info:
    # Store the ARN for use in other notebooks
    %store tracking_server_arn
    print("\n💡 Tip: In other notebooks, run '%store -r tracking_server_arn' to retrieve this value.")

## Step 9: Test Connection to MLflow Tracking Server

Let's verify we can connect to the tracking server using the MLflow Python client.

In [None]:
import mlflow
import boto3
from botocore.exceptions import ClientError

# Check if required variables are defined
try:
    tracking_server_name
    tracking_server_arn
except NameError:
    print("✗ Error: Required variables not found.")
    print("\nPlease run the previous steps first:")
    print("  - Step 2: Import libraries and initialize AWS clients")
    print("  - Step 3: Configure MLflow tracking server settings")
    print("  - Step 6: Create SageMaker MLflow tracking server")
    print("  - Step 7: Check tracking server status (ensure it's 'Created')")
    raise

# Initialize SageMaker client if not already defined
try:
    sagemaker_client
except NameError:
    session = boto3.Session()
    sagemaker_client = session.client('sagemaker')

# Check if tracking server is ready before testing connection
try:
    response = sagemaker_client.describe_mlflow_tracking_server(
        TrackingServerName=tracking_server_name
    )
    
    if response['TrackingServerStatus'] != 'Created':
        print("⚠️  Tracking server is not ready yet.")
        print(f"   Current status: {response['TrackingServerStatus']}")
        print("\nPlease go back to Step 7 and wait for the server to be 'Created' before testing the connection.")
    else:
        # Server is ready, test the connection
        try:
            # Set the tracking URI
            mlflow.set_tracking_uri(tracking_server_arn)
            
            # Create a test experiment
            experiment_name = "test-connection"
            mlflow.set_experiment(experiment_name)
            
            # Log a simple test run
            with mlflow.start_run(run_name="connection-test"):
                mlflow.log_param("test_param", "test_value")
                mlflow.log_metric("test_metric", 1.0)
            
            print("\n" + "="*80)
            print("✓ SUCCESS: Connected to MLflow tracking server!")
            print("="*80)
            print(f"✓ Created test experiment: {experiment_name}")
            print(f"✓ Logged test run successfully")
            print("\nYou can now view this in the MLflow UI (see Step 10).")
            print("="*80)
            
        except Exception as e:
            print(f"\n✗ Error connecting to MLflow tracking server: {e}")
            print("\nPossible issues:")
            print("  - Tracking server may not be fully ready yet")
            print("  - IAM permissions may not be configured correctly (check Step 5)")
            print("  - Network connectivity issues")
            
except ClientError as e:
    print(f"✗ Error checking tracking server: {e}")
    print("\nMake sure you've completed Steps 6 and 7 first.")

## Step 10: Access the MLflow UI

You can access the MLflow UI through SageMaker Studio:

**Steps to access:**
1. In SageMaker Studio, click **MLflow** in the left sidebar
2. You'll see your tracking server listed
3. Click on your tracking server name
4. Click **Open MLflow UI** button to view experiments, runs, and traces

**Note:** The MLflow tracking servers are managed through SageMaker Studio, not the main AWS Console.

In the MLflow UI, you should see:
- The "test-connection" experiment created in Step 9
- The test run with logged parameters and metrics

## Next Steps

Now that your MLflow tracking server is set up, you can:

1. **Use the tracking server ARN** in other notebooks:
   - Copy the ARN printed above
   - Replace `<YOUR-SAGEMAKERAI-MLFLOW-ARN>` in the observability notebooks with your ARN

2. **Explore the MLflow UI**:
   - View experiments and runs
   - Compare metrics across runs
   - Visualize traces for agent workflows

3. **Run the observability notebooks**:
   - `mlflow-crewAI-observability.ipynb` - Track CrewAI agent workflows
   - `mlflow-langgraph-observability.ipynb` - Track LangGraph agent workflows
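
When pasting the ARN into another notebook, a small format check can catch a leftover `<YOUR-SAGEMAKERAI-MLFLOW-ARN>` placeholder before any MLflow call is made. The sketch below is illustrative only: `parse_tracking_server_arn` is a hypothetical helper, the example ARN uses a placeholder account ID, and the pattern assumes the ARN layout matches the one printed in Step 8.

```python
import re

# Assumed ARN layout; compare with the ARN printed in Step 8.
_ARN_RE = re.compile(
    r"arn:(?P<partition>aws[a-z-]*):sagemaker:(?P<region>[a-z0-9-]+):"
    r"(?P<account>\d{12}):mlflow-tracking-server/(?P<name>[A-Za-z0-9-]+)"
)


def parse_tracking_server_arn(arn):
    """Split a tracking server ARN into its parts, or raise ValueError."""
    match = _ARN_RE.fullmatch(arn.strip())
    if match is None:
        raise ValueError(f"Not a tracking server ARN: {arn!r}")
    return match.groupdict()


# Placeholder ARN for illustration only
parts = parse_tracking_server_arn(
    "arn:aws:sagemaker:us-west-2:123456789012:"
    "mlflow-tracking-server/mlflow-tracking-server-workshop"
)
print(parts["region"], parts["name"])
```

Calling the helper with the unreplaced placeholder raises `ValueError`, which is easier to diagnose than a connection error from the MLflow client.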

## Cleanup

When you're done with the workshop, you can delete the tracking server to avoid ongoing charges:

In [None]:
# Uncomment and run this cell to delete the tracking server
# WARNING: This will delete all experiments and runs stored in the tracking server!

# try:
#     sagemaker_client.delete_mlflow_tracking_server(
#         TrackingServerName=tracking_server_name
#     )
#     print(f"Deleting tracking server: {tracking_server_name}")
#     print("Note: The S3 bucket with artifacts will not be deleted automatically.")
# except ClientError as e:
#     print(f"Error deleting tracking server: {e}")

# To also delete the S3 artifact bucket (optional):
# WARNING: This will permanently delete all stored artifacts!
# try:
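#     # The bucket is versioned (Step 4), so remove every object version and
#     # delete marker first; a versioned bucket cannot be deleted otherwise
#     s3 = boto3.resource('s3')
#     bucket = s3.Bucket(artifact_bucket_name)
#     bucket.object_versions.delete()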
#     bucket.delete()
#     print(f"Deleted S3 bucket: {artifact_bucket_name}")
# except ClientError as e:
#     print(f"Error deleting bucket: {e}")

## Summary

In this notebook, you:

✓ Created an S3 bucket for MLflow artifacts  
✓ Configured IAM permissions for MLflow access  
✓ Created a SageMaker MLflow tracking server  
✓ Verified the tracking server is operational  
✓ Tested the connection using the MLflow Python client  

Your MLflow tracking server is now ready to use for tracking experiments and tracing agent workflows in the other notebooks!

**Remember to save your tracking server ARN for use in other notebooks.**