# Legal Reasoning Model Training

This notebook demonstrates how to train the Legal Reasoning Model using SageMaker.

In [None]:
import os
import yaml
import boto3
import sagemaker
from sagemaker.huggingface import HuggingFace

## Load Configuration

In [None]:
# Load the training configuration from the default YAML file.
# The path is relative to the notebook's working directory.
# The encoding is given explicitly so parsing does not depend on the
# platform's locale default (which is not UTF-8 on every system).
with open('../configs/default_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Echo the settings that most affect the run as a quick sanity check.
print("Configuration loaded:")
print(f"Model: {config['model']['name']}")
print(f"Language: {config['model']['language']}")
print(f"Training with LoRA: {config['training']['use_lora']}")

## Set Up SageMaker Session

In [None]:
# Build a SageMaker session bound to the region named in the config.
region = config['aws']['region']
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.Session(boto_session=boto_session)

# Resolve the execution role through the session created above, so the lookup
# uses the configured region instead of an implicitly created default session.
role = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)

# S3 bucket and key prefix, both read from the config
bucket = config['aws']['s3_bucket']
prefix = config['aws']['s3_prefix']

print(f"SageMaker session initialized in region: {region}")
print(f"S3 bucket: {bucket}")
print(f"S3 prefix: {prefix}")

## Prepare Training Script

In [None]:
# Directory holding the training code, relative to the notebook's working directory
source_dir = "../src"

# Training script handed to the estimator as its entry point
# (presumably resolved relative to source_dir -- confirm against the SDK docs)
entry_point = "training/train_sagemaker.py"

# Report both settings, one per line.
print(
    "Source directory: " + source_dir,
    "Entry point: " + entry_point,
    sep="\n",
)

## Configure SageMaker Estimator

In [None]:
# Hyperparameters for the training job, all taken from the loaded config.
# NOTE(review): these are presumably parsed as command-line flags by
# training/train_sagemaker.py -- confirm the flag names against that script.
hyperparameters = {
    'model-name': config['model']['name'],
    'language': config['model']['language'],
    'max-seq-length': config['data']['max_seq_length'],
    'batch-size': config['training']['batch_size'],
    'learning-rate': config['training']['learning_rate'],
    'epochs': config['training']['num_train_epochs'],
    # Sent as a lowercase string ("true"/"false") rather than a Python bool.
    'use-lora': str(config['training']['use_lora']).lower(),
    'lora-rank': config['training']['lora_rank'],
    'lora-alpha': config['training']['lora_alpha'],
    'lora-dropout': config['training']['lora_dropout'],
    # NOTE(review): /opt/ml/model looks like the in-container model directory
    # that SageMaker packages as the job's artifact -- confirm the script
    # actually writes there.
    'output-dir': '/opt/ml/model',
}

# Create the HuggingFace estimator.
# The region-bound session built earlier is passed explicitly; without it the
# estimator falls back to a default session and ignores the configured region.
estimator = HuggingFace(
    entry_point=entry_point,
    source_dir=source_dir,
    role=role,
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    instance_count=config['aws']['instance_count'],
    instance_type=config['aws']['instance_type'],
    hyperparameters=hyperparameters,
    output_path=f"s3://{bucket}/{prefix}/output",
    sagemaker_session=sagemaker_session,
)

print("SageMaker estimator configured")

## Start Training Job

In [None]:
# Map each input channel name to the S3 location of the matching dataset split.
# Every split sits under the same bucket/prefix, in a folder named after the channel.
data_channels = {
    split: f"s3://{bucket}/{prefix}/{split}"
    for split in ("train", "validation", "test")
}

# List the channels, one indented line per channel.
print(
    "Data channels:",
    *(f"  {channel}: {path}" for channel, path in data_channels.items()),
    sep="\n",
)

In [None]:
# Launching the job is opt-in: the fit() call stays commented out so that
# running every cell of the notebook does not start a training job.
# Uncomment to run the training job
# estimator.fit(data_channels)

# Tell the reader that nothing was launched and how to launch it.
print(
    "Training job would be started with the above configuration.",
    "Uncomment the estimator.fit() line to actually start the training job.",
    sep="\n",
)

## Deploy the Trained Model

In [None]:
# Example deployment call, to be used once a training job has completed.
# It is left commented out, so running this cell only prints a reminder.
# predictor = estimator.deploy(
#     initial_instance_count=1,
#     instance_type='ml.g5.2xlarge',
#     endpoint_name='legal-reasoning-endpoint'
# )

# Point the reader at the commented-out example above.
print("After training, you can deploy the model using the code above.")