# SageMaker Hyperscaler Training for Legal Reasoning Model

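This notebook demonstrates how to train the Legal Reasoning Model with SageMaker's distributed model-parallel training (referred to in this project as "Hyperscaler"), targeting ml.g5.8xlarge instances for price-performance. The instance type and count are read from `configs/hyperscaler_config.yaml`.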

## Setup

First, let's set up the environment and import necessary libraries.

In [None]:
import datetime
import json
import os
import shlex
import subprocess
import sys

import boto3
import matplotlib.pyplot as plt
import pandas as pd
import sagemaker
import seaborn as sns
import yaml
from sagemaker.huggingface import HuggingFace

# Plot style: matplotlib's ggplot sheet first, then seaborn's whitegrid theme
# layered on top (the later call wins for any rcParams both of them set).
_MPL_STYLE_SHEET = 'ggplot'
_SNS_THEME_STYLE = "whitegrid"
plt.style.use(_MPL_STYLE_SHEET)
sns.set_theme(style=_SNS_THEME_STYLE)

## Configure AWS and SageMaker

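Create a boto3/SageMaker session in the chosen region, then resolve the SageMaker execution role and the default S3 bucket. Credentials come from the standard AWS credential chain; none are hard-coded in this notebook.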

In [None]:
# AWS region: set REGION_OVERRIDE to pin a region explicitly; otherwise use the
# region configured for the AWS SDK (AWS_DEFAULT_REGION or the AWS config file),
# falling back to us-east-1.
REGION_OVERRIDE = None  # e.g. 'eu-central-1'
region = REGION_OVERRIDE or boto3.Session().region_name or 'us-east-1'

# Create SageMaker session bound to that region
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.Session(boto_session=boto_session)

# Resolve the SageMaker execution role.
# An explicit ARN in SAGEMAKER_ROLE_ARN takes precedence. Otherwise ask SageMaker;
# get_execution_role() raises ValueError when the current identity is not a role
# (e.g. a local machine), so surface an actionable message instead.
role = os.environ.get('SAGEMAKER_ROLE_ARN')
if not role:
    try:
        role = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)
    except ValueError as err:
        raise RuntimeError(
            "Could not determine a SageMaker execution role. Run this notebook inside "
            "SageMaker, or set the SAGEMAKER_ROLE_ARN environment variable to the ARN "
            "of an IAM role that SageMaker can assume."
        ) from err

# Set S3 bucket and prefix
bucket = sagemaker_session.default_bucket()
prefix = 'legal-reasoning-model'

print(f"SageMaker Role ARN: {role}")
print(f"S3 Bucket: {bucket}")
print(f"S3 Prefix: {prefix}")

## Load Configuration

Load the hyperscaler configuration from the YAML file, then override its AWS region and S3 bucket with the values from the current session.

In [None]:
# Load the project configuration (path is relative to this notebook's directory).
config_path = "../configs/hyperscaler_config.yaml"

# Explicit UTF-8 so non-ASCII values load identically on every OS. An empty YAML
# file yields None, which is normalised to {} so the check below reports it.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f) or {}

# Later cells index these sections directly; fail here with a clear message
# rather than with a bare KeyError further down the notebook.
required_sections = ('aws', 'model', 'training', 'hyperscaler', 'data')
missing_sections = [name for name in required_sections if not isinstance(config.get(name), dict)]
if missing_sections:
    raise KeyError(f"{config_path} is missing required section(s): {', '.join(missing_sections)}")

# Update config with our session values
config['aws']['region'] = region
config['aws']['s3_bucket'] = bucket

# Display configuration. default=str keeps YAML scalars that JSON cannot
# represent (e.g. dates) from aborting the cell.
print("Model Configuration:")
print(json.dumps(config['model'], indent=2, default=str))

print("\nTraining Configuration:")
print(json.dumps(config['training'], indent=2, default=str))

print("\nHyperscaler Configuration:")
print(json.dumps(config['hyperscaler'], indent=2, default=str))

## Prepare Training Data

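Build the command that converts the processed examples and uploads them to S3 (via `scripts/prepare_data_for_hyperscaler.py`). Run it before launching the training job.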

In [None]:
# Set to True to run the data-preparation script from this cell; by default the
# command is only printed.
RUN_DATA_PREP = False

# Define paths (relative to this notebook's directory)
# NOTE(review): the input path is fixed to the German data set while `language`
# comes from the config — confirm they agree when training another language.
input_file = "../data/german/processed/all_examples.jsonl"
output_dir = "../data/hyperscaler"
language = config['model']['language']

# Check if input file exists
if not os.path.exists(input_file):
    print(f"Input file not found: {input_file}")
    print("Please run the data processing script first.")
    if RUN_DATA_PREP:
        raise FileNotFoundError(f"Input file not found: {input_file}")
else:
    print(f"Input file found: {input_file}")

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    # Argument vector for prepare_data_for_hyperscaler.py. sys.executable runs the
    # script with this kernel's interpreter, and passing a list (no shell) avoids
    # quoting problems when paths contain spaces or shell metacharacters.
    prep_args = [
        sys.executable, "../scripts/prepare_data_for_hyperscaler.py",
        "--input-file", input_file,
        "--output-dir", output_dir,
        "--s3-bucket", bucket,
        "--s3-prefix", f"{prefix}/data",
        "--language", str(language),
    ]
    # Shell-quoted single-line form, usable as `!{cmd}` or pasted into a terminal.
    cmd = shlex.join(prep_args)

    print("\nCommand to prepare data:")
    print(cmd)

    if RUN_DATA_PREP:
        # check=True turns a failing script into an exception instead of a silent no-op.
        subprocess.run(prep_args, check=True)

## Configure Model Parallelism

Set up the SageMaker Hyperscaler (model parallelism) configuration.

In [None]:
# Build the `distribution` argument for SageMaker's model parallelism library from
# the `hyperscaler` section of the config. The parameter names used here
# (microbatches, optimize, placement_strategy, prescaled_batch, ddp, ...) are the
# SMP v1 ones, and SMP v1 jobs are launched through MPI.
hyperscaler_cfg = config['hyperscaler']

# SMP v1 reads `partitions` and `pipeline_parallel_degree` as two names for the same
# setting. Only the explicit `pipeline_parallel_degree` is sent, so two different
# config keys cannot submit contradictory values. `model_parallel_degree` is still
# passed to the training script through the hyperparameters.
pipeline_degree = hyperscaler_cfg['pipeline_parallel_degree']
tensor_degree = hyperscaler_cfg['tensor_parallel_degree']

distribution = {
    "smdistributed": {
        "modelparallel": {
            "enabled": True,
            "parameters": {
                "microbatches": hyperscaler_cfg['microbatches'],
                "optimize": hyperscaler_cfg['optimize'],
                "pipeline_parallel_degree": pipeline_degree,
                "tensor_parallel_degree": tensor_degree,
                "ddp": True,
                "placement_strategy": hyperscaler_cfg['placement_strategy'],
                # NOTE(review): confirm the installed SMP version accepts this key;
                # SMP v1 also exposes activation checkpointing through its Python API.
                "activation_checkpointing": hyperscaler_cfg['activation_checkpointing'],
                "prescaled_batch": hyperscaler_cfg['prescaled_batch'],
                "shard_optimizer_state": hyperscaler_cfg['shard_optimizer_state']
            }
        }
    },
    # SMP v1 is launched with MPI (torch_distributed/torchrun is the SMP v2 launcher).
    # processes_per_host should equal the number of GPUs per instance; it defaults to
    # the minimum for one model replica (pipeline x tensor degree). Set
    # `hyperscaler.processes_per_host` in the config to override.
    "mpi": {
        "enabled": True,
        "processes_per_host": hyperscaler_cfg.get('processes_per_host', pipeline_degree * tensor_degree)
    }
}

print("Model Parallelism Configuration:")
print(json.dumps(distribution, indent=2))

## Set Up Hyperparameters

Configure the training hyperparameters.

In [None]:
# Hyperparameters passed to the entry point (train_hyperscaler.py).
def _as_flag(value):
    """Render a config value as a lowercase string, e.g. True -> 'true', False -> 'false'."""
    return str(value).lower()


training_cfg = config['training']

# Hugging Face TrainingArguments rejects fp16 and bf16 enabled together; catch the
# mistake here instead of after the (paid) training instances have started.
if _as_flag(training_cfg['fp16']) == 'true' and _as_flag(training_cfg['bf16']) == 'true':
    raise ValueError("config['training'] enables both fp16 and bf16; enable at most one.")

hyperparameters = {
    # Model configuration
    "model_id": config['model']['name'],
    "language": config['model']['language'],
    "max_seq_length": config['data']['max_seq_length'],

    # Training configuration
    "epochs": training_cfg['num_train_epochs'],
    "per_device_train_batch_size": training_cfg['batch_size'],
    "per_device_eval_batch_size": training_cfg['batch_size'],
    "gradient_accumulation_steps": training_cfg['gradient_accumulation_steps'],
    "learning_rate": training_cfg['learning_rate'],
    "weight_decay": training_cfg['weight_decay'],
    "warmup_steps": training_cfg['warmup_steps'],

    # LoRA configuration
    "use_lora": _as_flag(training_cfg['use_lora']),
    "lora_r": training_cfg['lora_rank'],
    "lora_alpha": training_cfg['lora_alpha'],
    "lora_dropout": training_cfg['lora_dropout'],
    # NOTE(review): if this is a YAML list, confirm how train_hyperscaler.py parses it —
    # SageMaker serializes non-string hyperparameter values before handing them to the script.
    "lora_target_modules": training_cfg['lora_target_modules'],

    # Hyperscaler configuration
    "model_parallel_degree": config['hyperscaler']['model_parallel_degree'],
    "ddp_dist_backend": "nccl",
    "fp16": _as_flag(training_cfg['fp16']),
    "bf16": _as_flag(training_cfg['bf16']),

    # Optimization
    # NOTE(review): a DeepSpeed ZeRO-3 config alongside SMP model parallelism is
    # unusual — confirm the entry point does not enable both at once.
    "deepspeed_config": "ds_z3_config.json",  # Will be created in the entry point
    "torch_distributed": "true",

    # Checkpointing (overridable via the training config; defaults 500 / 2)
    "save_strategy": "steps",
    "save_steps": training_cfg.get('save_steps', 500),
    "save_total_limit": training_cfg.get('save_total_limit', 2),

    # Evaluation (overridable via the training config; defaults 500 / 100).
    # `evaluation_strategy` is the argument name used by the pinned transformers 4.28.1.
    "evaluation_strategy": "steps",
    "eval_steps": training_cfg.get('eval_steps', 500),
    "logging_steps": training_cfg.get('logging_steps', 100),

    # Output
    "output_dir": "/opt/ml/model"
}

print("Training Hyperparameters:")
print(json.dumps(hyperparameters, indent=2))

## Create SageMaker Estimator

Create a SageMaker HuggingFace estimator with Hyperscaler configuration.

In [None]:
# Job naming. `job_name` is a unique, timestamped name for this run; pass it as
# `fit(job_name=job_name)` to use it verbatim. The estimator receives only the
# un-timestamped prefix as `base_job_name`: SageMaker appends its own timestamp to
# base_job_name, so handing it the timestamped name would double the timestamp and
# push past the 63-character job-name limit (the SDK then trims the prefix).
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
base_job_name = f"legal-reasoning-{config['model']['language']}"
job_name = f"{base_job_name}-{timestamp}"

# Set up instance configuration
# NOTE(review): ml.g5.8xlarge has a single GPU; a model-parallel degree > 1 needs a
# multi-GPU instance type (e.g. ml.g5.12xlarge or ml.g5.48xlarge) — confirm vs. config.
instance_type = config['aws']['instance_type']
instance_count = config['aws']['instance_count']
use_spot = config['aws']['use_spot']
max_wait = config['aws']['max_wait'] if use_spot else None
max_run = config['aws']['max_run']

# SageMaker requires the spot wait time to be at least the maximum run time;
# check before creating the job rather than letting the API reject it.
if use_spot and (max_wait is None or max_wait < max_run):
    raise ValueError(
        f"aws.max_wait ({max_wait}) must be set and >= aws.max_run ({max_run}) when use_spot is enabled."
    )

# Set up output path
output_path = f"s3://{bucket}/{prefix}/model"
# Spot jobs sync checkpoints with S3 so an interrupted job can resume.
# NOTE(review): SageMaker syncs the container's checkpoint directory (default
# /opt/ml/checkpoints) with this URI, while the hyperparameters set output_dir to
# /opt/ml/model — confirm the entry point writes its checkpoints to /opt/ml/checkpoints.
checkpoint_s3_uri = f"s3://{bucket}/{prefix}/checkpoints" if use_spot else None

# Create HuggingFace estimator
huggingface_estimator = HuggingFace(
    entry_point="train_hyperscaler.py",  # Custom training script with model parallelism
    source_dir="../src/training",         # Directory containing the training code
    instance_type=instance_type,
    instance_count=instance_count,
    role=role,
    transformers_version="4.28.1",
    pytorch_version="2.0.0",
    py_version="py310",
    hyperparameters=hyperparameters,
    distribution=distribution,
    use_spot_instances=use_spot,
    max_wait=max_wait,
    max_run=max_run,
    checkpoint_s3_uri=checkpoint_s3_uri,
    output_path=output_path,
    base_job_name=base_job_name,
    sagemaker_session=sagemaker_session
)

print(f"Job Name: {job_name}")
print(f"Instance Type: {instance_type}")
print(f"Using Spot Instances: {use_spot}")
print(f"Output Path: {output_path}")