# SageMaker Hyperscaler Training for Legal Reasoning Model (Part 2)

This notebook demonstrates how to train the Legal Reasoning Model using SageMaker Hyperscaler on ml.g5.8xlarge instances for optimal price-performance.

## Part 2: Model Parallelism Configuration

### Setup

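First, let's import the necessary libraries, load the YAML configuration, and create the SageMaker session (execution role, default S3 bucket, and key prefix) used by the rest of the notebook.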

In [None]:
import os
import json
import yaml
import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace

# Load configuration
config_path = "../configs/hyperscaler_config.yaml"

# Explicit encoding so the file is decoded the same way on every platform.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Fail fast with a readable message: yaml.safe_load returns None for an empty
# file, and the cells below index into these four top-level sections.
required_sections = ('model', 'data', 'training', 'hyperscaler')
if not isinstance(config, dict):
    raise ValueError(
        f"{config_path} must contain a YAML mapping, got {type(config).__name__}"
    )
missing_sections = [name for name in required_sections if name not in config]
if missing_sections:
    raise ValueError(
        f"{config_path} is missing required section(s): {', '.join(missing_sections)}"
    )

# Set AWS region
region = 'us-east-1'  # Change to your preferred region

# Create SageMaker session
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.Session(boto_session=boto_session)
# Pass the session explicitly so the role lookup uses the same region and
# credentials as everything else here, not an implicitly created default session.
role = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)
bucket = sagemaker_session.default_bucket()
prefix = 'legal-reasoning-model'

### Configure Model Parallelism

Set up the SageMaker Hyperscaler (model parallelism) configuration.

In [None]:
# Build the `distribution` dictionary for model parallelism.
# All tunable values come from the `hyperscaler` section of the loaded config.
hyperscaler_cfg = config['hyperscaler']

# "partitions" is the only parameter whose name differs from its config key.
# NOTE(review): "partitions" and "pipeline_parallel_degree" are both supplied;
# the model parallel library docs may treat them as aliases — confirm that
# sending both (and combining these parameters with torch_distributed) is intended.
smp_parameters = {"partitions": hyperscaler_cfg['model_parallel_degree']}
for key in (
    "microbatches",
    "optimize",
    "pipeline_parallel_degree",
    "tensor_parallel_degree",
):
    smp_parameters[key] = hyperscaler_cfg[key]

# Fixed value, not read from the config file.
smp_parameters["ddp"] = True

for key in (
    "placement_strategy",
    "activation_checkpointing",
    "prescaled_batch",
    "shard_optimizer_state",
):
    smp_parameters[key] = hyperscaler_cfg[key]

distribution = {
    "smdistributed": {
        "modelparallel": {"enabled": True, "parameters": smp_parameters},
    },
    "torch_distributed": {"enabled": True},
}

# Show the resulting structure for a quick sanity check.
print("Model Parallelism Configuration:")
print(json.dumps(distribution, indent=2))

### Set Up Hyperparameters

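Build the training hyperparameter dictionary. Most values are read from the loaded configuration; boolean flags are converted to lowercase strings (`"true"` / `"false"`), and the checkpointing, evaluation, and logging intervals are fixed in this cell.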

In [None]:
# Assemble the hyperparameters group by group, then merge them in order.
def _as_flag(value):
    """Render a config value as a lowercase string (e.g. True -> "true")."""
    return str(value).lower()


# Model configuration
model_cfg = config['model']
model_params = {
    "model_id": model_cfg['name'],
    "language": model_cfg['language'],
    "max_seq_length": config['data']['max_seq_length'],
}

# Training configuration (train and eval share the same batch size setting)
training_cfg = config['training']
training_params = {
    "epochs": training_cfg['num_train_epochs'],
    "per_device_train_batch_size": training_cfg['batch_size'],
    "per_device_eval_batch_size": training_cfg['batch_size'],
    "gradient_accumulation_steps": training_cfg['gradient_accumulation_steps'],
    "learning_rate": training_cfg['learning_rate'],
    "weight_decay": training_cfg['weight_decay'],
    "warmup_steps": training_cfg['warmup_steps'],
}

# LoRA configuration (also stored under the `training` section)
lora_params = {
    "use_lora": _as_flag(training_cfg['use_lora']),
    "lora_r": training_cfg['lora_rank'],
    "lora_alpha": training_cfg['lora_alpha'],
    "lora_dropout": training_cfg['lora_dropout'],
    "lora_target_modules": training_cfg['lora_target_modules'],
}

# Hyperscaler configuration
hyperscaler_params = {
    "model_parallel_degree": config['hyperscaler']['model_parallel_degree'],
    "ddp_dist_backend": "nccl",
    "fp16": _as_flag(training_cfg['fp16']),
    "bf16": _as_flag(training_cfg['bf16']),
}

# Optimization
optimization_params = {
    "deepspeed_config": "ds_z3_config.json",  # Will be created in the entry point
    "torch_distributed": "true",
}

# Checkpointing
checkpoint_params = {
    "save_strategy": "steps",
    "save_steps": 500,
    "save_total_limit": 2,
}

# Evaluation and logging cadence
evaluation_params = {
    "evaluation_strategy": "steps",
    "eval_steps": 500,
    "logging_steps": 100,
}

# Output
output_params = {"output_dir": "/opt/ml/model"}

# Merge order matches the grouping above, so the printed JSON keeps that layout.
hyperparameters = {
    **model_params,
    **training_params,
    **lora_params,
    **hyperscaler_params,
    **optimization_params,
    **checkpoint_params,
    **evaluation_params,
    **output_params,
}

print("Training Hyperparameters:")
print(json.dumps(hyperparameters, indent=2))