# SageMaker Hyperscaler Training for Legal Reasoning Model (Part 3)

This notebook shows how to train the Legal Reasoning Model with SageMaker Hyperscaler on ml.g5.8xlarge instances, chosen for their price-performance.

## Part 3: Training Job Configuration and Execution

### Setup

First, let's import the necessary libraries and load the configuration.

In [None]:
import os
import json
import yaml
import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace
import datetime

# Load configuration
config_path = "../configs/hyperscaler_config.yaml"

with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Set AWS region
region = 'us-east-1'  # Change to your preferred region

# Create SageMaker session
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.Session(boto_session=boto_session)
role = sagemaker.get_execution_role()
bucket = sagemaker_session.default_bucket()
prefix = 'legal-reasoning-model'
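Later cells index into `config` directly, so a missing key only surfaces as a `KeyError` partway through the notebook. The sketch below checks for the keys this notebook reads before anything else runs. The helper name `missing_config_keys` and the sample values are hypothetical; only the key names come from the lookups used in this notebook.

```python
# Hypothetical helper: the key names mirror the lookups used later in this notebook.
REQUIRED_KEYS = {
    "model": ["language"],
    "aws": ["instance_type", "instance_count", "use_spot", "max_run"],
}


def missing_config_keys(config):
    """Return dotted paths of required keys that are absent from config."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        section_values = config.get(section) or {}
        for key in keys:
            if key not in section_values:
                missing.append(f"{section}.{key}")
    # max_wait is only read when spot training is enabled
    aws = config.get("aws") or {}
    if aws.get("use_spot") and "max_wait" not in aws:
        missing.append("aws.max_wait")
    return missing


# Placeholder values for illustration, not the project's real configuration
sample = {
    "model": {"language": "en"},
    "aws": {"instance_type": "ml.g5.8xlarge", "instance_count": 1,
            "use_spot": True, "max_run": 3600},
}
print(missing_config_keys(sample))  # ['aws.max_wait']
```

In the notebook itself, calling `missing_config_keys(config)` right after loading the YAML file and stopping when the list is non-empty would be one way to use it.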

### Create SageMaker Estimator

Create a SageMaker HuggingFace estimator with Hyperscaler configuration.

In [None]:
# `distribution` and `hyperparameters` are not defined in this notebook.
# Run the Part 2 notebook in the same kernel first, or define both variables
# here, before creating the estimator in the next cell.

# Set up training job name
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
job_name = f"legal-reasoning-{config['model']['language']}-{timestamp}"

# Set up instance configuration
instance_type = config['aws']['instance_type']
instance_count = config['aws']['instance_count']
use_spot = config['aws']['use_spot']
max_wait = config['aws']['max_wait'] if use_spot else None
max_run = config['aws']['max_run']

# Set up output path
output_path = f"s3://{bucket}/{prefix}/model"
checkpoint_s3_uri = f"s3://{bucket}/{prefix}/checkpoints" if use_spot else None

print(f"Job Name: {job_name}")
print(f"Instance Type: {instance_type}")
print(f"Using Spot Instances: {use_spot}")
print(f"Output Path: {output_path}")
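The job name above embeds a value taken from the config file. SageMaker training job names may contain only letters, digits, and hyphens, must start and end with a letter or digit, and are limited to 63 characters, so a language code such as `pt_BR` would be rejected. The sketch below shows one possible way to normalize the name; `sanitize_job_name` is a hypothetical helper, not part of the SageMaker SDK.

```python
import re

MAX_JOB_NAME_LENGTH = 63  # SageMaker limit for training job names


def sanitize_job_name(name):
    """Replace disallowed characters with hyphens and trim to the length limit."""
    cleaned = re.sub(r"[^a-zA-Z0-9-]", "-", name)
    cleaned = cleaned.strip("-")[:MAX_JOB_NAME_LENGTH]
    # Trimming can leave a trailing hyphen, which is not allowed
    return cleaned.rstrip("-")


print(sanitize_job_name("legal-reasoning-pt_BR-2024-01-01-12-00-00"))
```

Because the trim happens at the end of the string, a very long prefix would cut into the timestamp; keeping the prefix short avoids that.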

In [None]:
# Create HuggingFace estimator
huggingface_estimator = HuggingFace(
    entry_point="train_hyperscaler.py",  # Custom training script with model parallelism
    source_dir="../src/training",         # Directory containing the training code
    instance_type=instance_type,
    instance_count=instance_count,
    role=role,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    distribution=distribution,
    use_spot_instances=use_spot,
    max_wait=max_wait,
    max_run=max_run,
    checkpoint_s3_uri=checkpoint_s3_uri,
    output_path=output_path,
    base_job_name=job_name,
    sagemaker_session=sagemaker_session
)
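The spot-related arguments passed to the estimator above interact: with managed spot training, `max_wait` has to be set and must be at least as large as `max_run`, and checkpoints are what allow an interrupted job to resume. The following sketch checks those relationships before a job is submitted. `check_spot_settings` is a hypothetical helper, and the messages are illustrative.

```python
def check_spot_settings(use_spot, max_run, max_wait, checkpoint_s3_uri):
    """Return a list of problems found; an empty list means none were detected."""
    problems = []
    if use_spot:
        if max_wait is None:
            problems.append("max_wait must be set when spot training is enabled")
        elif max_wait < max_run:
            problems.append("max_wait must be greater than or equal to max_run")
        if not checkpoint_s3_uri:
            problems.append("checkpoint_s3_uri is not set; progress may be lost on interruption")
    elif max_wait is not None:
        problems.append("max_wait only applies to spot training")
    return problems


# Placeholder values: one day of runtime with a shorter wait window
for problem in check_spot_settings(True, 86400, 3600, "s3://example-bucket/checkpoints"):
    print(problem)
```

Running this check with the values read from the config file, just before constructing the estimator, would surface a misconfiguration locally instead of as an API error.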

### Start Training Job

Define data channels and start the training job.

In [None]:
# Define data channels
data_channels = {
    "train": f"s3://{bucket}/{prefix}/data/train",
    "validation": f"s3://{bucket}/{prefix}/data/validation"
}

print("Data Channels:")
print(json.dumps(data_channels, indent=2))
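A training job fails early if a channel points at an empty S3 prefix, so it can be worth confirming that each prefix holds at least one object before submitting. The sketch below does this with the S3 `list_objects_v2` call. The helper names are hypothetical, and the client is passed in so the function can be exercised without AWS access.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/key/prefix' into (bucket, key_prefix)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


def empty_channels(s3_client, channels):
    """Return the names of channels whose S3 prefix contains no objects."""
    empty = []
    for name, uri in channels.items():
        bucket_name, key_prefix = split_s3_uri(uri)
        # MaxKeys=1 is enough to learn whether anything exists under the prefix
        response = s3_client.list_objects_v2(
            Bucket=bucket_name, Prefix=key_prefix, MaxKeys=1
        )
        if response.get("KeyCount", 0) == 0:
            empty.append(name)
    return empty


# Usage in this notebook (requires AWS credentials):
# print(empty_channels(boto_session.client("s3"), data_channels))
```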

In [None]:
# Start training job
# Note: This will actually start the training job on SageMaker
# Uncomment the following line to run the training job

# huggingface_estimator.fit(inputs=data_channels, job_name=job_name)

print("To start the training job, uncomment the line above.")
print("Note: This will incur AWS charges for SageMaker training.")

### Monitor Training Job

Monitor the training job progress.

In [None]:
# Check training job status
# Note: This assumes you've started a training job

# Replace with your actual job name if you started a training job
example_job_name = job_name

print("To check the status of your training job, run:")
print(f"aws sagemaker describe-training-job --training-job-name {example_job_name}")

print("\nTo view the logs, run:")
print(f"aws logs get-log-events --log-group-name /aws/sagemaker/TrainingJobs --log-stream-name {example_job_name}/algo-1-XXXXXXXXXX")
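The same status information that the CLI command returns is available from Python through the `describe_training_job` API. The sketch below polls that call until the job reaches a terminal state. `wait_for_training_job` and its parameters are hypothetical, and the SageMaker client is passed in as an argument.

```python
import time

# Statuses after which the job will not change again
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(sm_client, training_job_name, poll_seconds=60, max_polls=None):
    """Poll describe_training_job until a terminal status or max_polls is reached."""
    polls = 0
    while True:
        description = sm_client.describe_training_job(TrainingJobName=training_job_name)
        status = description["TrainingJobStatus"]
        print(f"{status} / {description.get('SecondaryStatus', 'n/a')}")
        polls += 1
        if status in TERMINAL_STATUSES:
            return description
        if max_polls is not None and polls >= max_polls:
            return description
        time.sleep(poll_seconds)


# Usage in this notebook (requires AWS credentials and a started job):
# result = wait_for_training_job(boto_session.client("sagemaker"), job_name)
```

The `max_polls` argument is there so a notebook cell does not block indefinitely; leaving it as `None` waits until the job finishes.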

### Cost Analysis

Estimate the cost of the training job under on-demand and spot pricing.

In [None]:
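# Calculate estimated cost
# Prices are illustrative assumptions; check current SageMaker pricing for your region
estimated_hours = 25  # Estimated training time in hours
on_demand_price = 5.76  # Assumed ml.g5.8xlarge on-demand price per instance-hour (USD)
spot_price = 1.73  # Assumed spot price per instance-hour (approximately 30% of on-demand)

# Every instance in the job is billed, so scale by instance_count
on_demand_cost = on_demand_price * estimated_hours * instance_count
spot_cost = spot_price * estimated_hours * instance_count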

print(f"Estimated training time: {estimated_hours} hours")
print(f"On-demand cost: ${on_demand_cost:.2f}")
print(f"Spot cost: ${spot_cost:.2f}")
print(f"Cost savings with spot: ${on_demand_cost - spot_cost:.2f} ({(1 - spot_cost/on_demand_cost)*100:.1f}%)")
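The figures above are estimates made before the job runs. Once a job has finished, the `describe_training_job` response reports `TrainingTimeInSeconds` and `BillableTimeInSeconds`, and the realized spot savings follow from their ratio. The sketch below computes that percentage; `spot_savings_percent` is a hypothetical helper, and the sample dictionary holds made-up values rather than results from a real job.

```python
def spot_savings_percent(description):
    """Compute realized spot savings from a describe_training_job response."""
    training_seconds = description["TrainingTimeInSeconds"]
    billable_seconds = description["BillableTimeInSeconds"]
    if training_seconds <= 0:
        raise ValueError("TrainingTimeInSeconds must be positive")
    return (1 - billable_seconds / training_seconds) * 100


# Made-up values for illustration only
sample_description = {"TrainingTimeInSeconds": 90000, "BillableTimeInSeconds": 27000}
print(f"Realized savings: {spot_savings_percent(sample_description):.1f}%")
```

For an on-demand job the two fields are equal, so the function returns zero.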