# FSDP Multi-GPU Training on SageMaker

Train medical image segmentation models with Fully Sharded Data Parallel (FSDP) across multiple GPUs.
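To see why sharding matters on a 4-GPU instance, here is a rough estimate of per-GPU memory for model state. It compares plain DDP, where every GPU holds a full replica, with FSDP's default `FULL_SHARD` strategy, where that state is split across GPUs. It assumes fp32 training with Adam, which is about 16 bytes per parameter. It ignores activations and the parameters FSDP temporarily all-gathers during forward and backward, so it is a lower bound, not a measurement. The 100M parameter count is hypothetical.

```python
def model_state_bytes_per_gpu(num_params: int, world_size: int,
                              bytes_per_param: int = 16,
                              sharded: bool = True) -> float:
    """Rough per-GPU memory for parameters, gradients and Adam state.

    Assumes fp32 parameters (4 B) + fp32 gradients (4 B) + two fp32 Adam
    moment buffers (8 B) = 16 bytes per parameter. Under full sharding each
    rank keeps 1/world_size of this state; under plain DDP every rank keeps
    a full replica. Activations and temporarily gathered parameters are ignored.
    """
    total = num_params * bytes_per_param
    return total / world_size if sharded else float(total)


n_params = 100_000_000  # hypothetical model size, for illustration only
gpus = 4                # ml.g4dn.12xlarge has 4 GPUs
for label, sharded in (("DDP", False), ("FSDP", True)):
    gib = model_state_bytes_per_gpu(n_params, gpus, sharded=sharded) / 2**30
    print(f"{label}: {gib:.2f} GiB of model state per GPU")
```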

## Setup

In [1]:
import sagemaker
from sagemaker.pytorch.estimator import PyTorch
from sagemaker import get_execution_role
import boto3

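# Pin the region explicitly when running outside SageMaker
sagemaker_session = sagemaker.Session(boto3.Session(region_name='us-east-1'))

# Inside a SageMaker notebook instance or Studio, you can use these instead:
# sagemaker_session = sagemaker.Session()
# role = get_execution_role()
# Outside SageMaker, name an existing execution role (replace with your own)
role = "AmazonSageMaker-ExecutionRole-20240907T181142"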

region = sagemaker_session.boto_region_name
bucket = sagemaker_session.default_bucket()

print(f"Region: {region}")
print(f"Bucket: {bucket}")

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ubuntu/.config/sagemaker/config.yaml
Region: us-east-1
Bucket: sagemaker-us-east-1-575108919340
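A bare role name works with the SageMaker SDK because it expands names to ARNs before creating the job (see `Session.expand_role`). You may want the full ARN yourself, for example to log it or to pass it to boto3 calls. The following minimal sketch does that with a hypothetical helper, `resolve_role_arn`. The IAM client is passed in, so a stub can stand in for it during testing. Looking the role up is safer than formatting the ARN by hand, because console-created execution roles often live under a `service-role/` path.

```python
from typing import Any


def resolve_role_arn(role: str, iam_client: Any) -> str:
    """Return an IAM role ARN, looking it up by name when needed.

    `iam_client` only needs a `get_role(RoleName=...)` method, such as
    the one on `boto3.client("iam")`.
    """
    if role.startswith("arn:"):
        return role  # already a full ARN; no lookup needed
    response = iam_client.get_role(RoleName=role)
    return response["Role"]["Arn"]
```

In the notebook this would be called as `role = resolve_role_arn(role, boto3.client('iam'))`. The caller needs the `iam:GetRole` permission.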


## Data Path

Point to your S3 data location:

In [2]:
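data_bucket = 'public-datasets-imaging-us-east-1'  # separate name keeps `bucket` (session default) intact
data_path = f's3://{data_bucket}/segmentation_data/'
# Not passed to the estimator below, so artifacts still go to the default bucket.
# If you pass it as output_path=..., keep it outside data_path: the training
# channel downloads every object under that prefix.
output_path = f's3://{data_bucket}/segmentation_data/output'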

print(f"Training data: {data_path}")
print(f"Output path: {output_path}")

Training data: s3://public-datasets-imaging-us-east-1/segmentation_data/
Output path: s3://public-datasets-imaging-us-east-1/segmentation_data/output
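Because `fit` takes S3 locations as plain strings, a quick sanity check can catch typos before a job is launched. The following standard-library sketch uses a hypothetical helper, `split_s3_uri`, to split an `s3://` URI into its bucket and key prefix. The SDK's `sagemaker.s3.parse_s3_url` does something similar.

```python
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> tuple[str, str]:
    """Split an s3://bucket/key URI into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    # Keys never start with "/", so drop the leading slash from the path
    return parsed.netloc, parsed.path.lstrip("/")
```

A channel URI is treated as a key prefix. The trailing slash in `segmentation_data/` prevents it from also matching sibling prefixes such as `segmentation_data_v2/`.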


## Configure Training Job

In [3]:
# Hyperparameters
hyperparameters = {
    'model_name': 'SegResNet',
    'batch_size': 2,
    'epochs': 10,
    'lr': 0.0001
}

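# PyTorch estimator. The `pytorchddp` distribution only launches one training
# process per GPU; sharding the model with FSDP is expected to happen in train_fsdp.py.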
estimator = PyTorch(
    entry_point='train_fsdp.py',
    source_dir='../code/training',
    role=role,
    instance_type='ml.g4dn.12xlarge',  # 4x NVIDIA T4 GPUs, 16 GB each
    instance_count=1,
    framework_version='2.0.0',
    py_version='py310',
    hyperparameters=hyperparameters,
    distribution={
        'pytorchddp': {
            'enabled': True
        }
    },
    keep_alive_period_in_seconds=600,  # warm pool: keep the instance 10 min after the job (billed) for faster reruns
    disable_profiler=True,
    debugger_hook_config=False,
    sagemaker_session=sagemaker_session,
)

print("Estimator configured successfully")

Estimator configured successfully
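SageMaker forwards the `hyperparameters` dictionary to the entry point as command-line flags (for example `--batch_size 2`). `train_fsdp.py` is not shown here, so the following is only a sketch of how such a script might read these four values. `parse_hyperparameters` is a hypothetical name, and the defaults simply mirror the dictionary above.

```python
import argparse
from typing import Optional, Sequence


def parse_hyperparameters(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the hyperparameters SageMaker passes as --name value flags."""
    parser = argparse.ArgumentParser(description="FSDP segmentation training (sketch)")
    parser.add_argument("--model_name", type=str, default="SegResNet")
    parser.add_argument("--batch_size", type=int, default=2)  # assumed per process/GPU
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-4)
    # parse_known_args tolerates extra flags a launcher might append
    args, _unknown = parser.parse_known_args(argv)
    return args
```

Assuming `batch_size` is per process, one `ml.g4dn.12xlarge` processes 2 × 4 = 8 samples per optimizer step in total.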


## Start Training

In [4]:
# Start training job
estimator.fit({'training': data_path}, wait=True, logs='All')

Using provided s3_resource


INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2026-01-25-20-44-53-697


2026-01-25 20:44:54 Starting - Starting the training job...
2026-01-25 20:45:23 Pending - Training job waiting for capacity...
2026-01-25 20:45:48 Pending - Preparing the instances for training...
2026-01-25 20:46:17 Downloading - Downloading input data...
2026-01-25 20:46:42 Downloading - Downloading the training image..................
2026-01-25 20:49:54 Training - Training image download completed. Training in progress.......
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2026-01-25 20:50:45,759 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2026-01-25 20:50:45,794 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2026-01-25 20:50:45,803 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2026-01-25 20:50:45,805 sagemaker_pytorch_container.training INFO     Invoking SMDataParallel for native PT DDP jo
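Inside the container, the `'training'` channel passed to `fit` is made available under `/opt/ml/input/data/training`, and its path is exposed as the `SM_CHANNEL_TRAINING` environment variable. Files saved under `SM_MODEL_DIR` (`/opt/ml/model`) are packaged into `model.tar.gz` after the job. The following sketch shows how a training script might locate both, with fallbacks for local testing. `sagemaker_paths` is a hypothetical helper.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def sagemaker_paths(env: Optional[Mapping[str, str]] = None) -> dict[str, Path]:
    """Locate the input channel and model directory inside a SageMaker job."""
    env = os.environ if env is None else env
    return {
        # Channel "training" from estimator.fit({'training': ...})
        "train_dir": Path(env.get("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training")),
        # Files written here are packed into model.tar.gz and uploaded to S3
        "model_dir": Path(env.get("SM_MODEL_DIR", "/opt/ml/model")),
    }
```

With one process per GPU, usually only one rank should write to `model_dir`. Under FSDP, saving a single consolidated checkpoint generally means gathering the sharded state first, for example with `StateDictType.FULL_STATE_DICT` and `rank0_only=True`.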

## Get Model Artifacts

In [5]:
# Model artifacts location
model_data = estimator.model_data
print(f"Model artifacts: {model_data}")

Model artifacts: s3://sagemaker-us-east-1-575108919340/pytorch-training-2026-01-25-20-44-53-697/output/model.tar.gz
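`model_data` points to a gzipped tarball of whatever the script saved to `/opt/ml/model`. To inspect it locally, download it (for instance with `sagemaker.s3.S3Downloader.download(model_data, 'model_artifacts')`) and unpack it. The following standard-library sketch covers the unpacking step. It uses a hypothetical helper, `extract_model_tarball`, that rejects archive members that would land outside the target directory.

```python
import tarfile
from pathlib import Path


def extract_model_tarball(tar_path: str, dest_dir: str) -> list[str]:
    """Unpack model.tar.gz into dest_dir and return the member names."""
    dest = Path(dest_dir).resolve()
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            target = (dest / member.name).resolve()
            # Refuse absolute paths or ".." entries that would escape dest
            if target != dest and dest not in target.parents:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")  # Pythons with extraction filters
        else:
            tar.extractall(dest)
    return [m.name for m in members]
```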


## Training Metrics

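Training logs stream to CloudWatch Logs under the `/aws/sagemaker/TrainingJobs` log group, and `estimator.logs()` re-displays them in the notebook. Algorithm metrics such as loss appear in CloudWatch only if the estimator defines `metric_definitions`. TensorBoard logs are available in S3 only if the training script writes them somewhere SageMaker uploads, for example a path configured with `tensorboard_output_config`.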
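For a custom script such as `train_fsdp.py`, `metric_definitions` is a list of `{'Name': ..., 'Regex': ...}` entries. SageMaker matches each regex against the job's log output and records the captured number as the metric value. The script's real log format is not shown here, so the `train_loss=...` / `val_dice=...` format below is a hypothetical example. Checking the regexes locally before launching a job is cheap.

```python
import re

# Hypothetical definitions; the same list could be passed to
# PyTorch(..., metric_definitions=METRIC_DEFINITIONS).
METRIC_DEFINITIONS = [
    {"Name": "train:loss", "Regex": r"train_loss=([0-9.]+)"},
    {"Name": "val:dice", "Regex": r"val_dice=([0-9.]+)"},
]


def extract_metrics(log_line: str) -> dict[str, float]:
    """Apply each metric regex to one log line, roughly as SageMaker does."""
    found = {}
    for definition in METRIC_DEFINITIONS:
        match = re.search(definition["Regex"], log_line)
        if match:
            found[definition["Name"]] = float(match.group(1))
    return found
```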

In [6]:
# Get training job name
training_job_name = estimator.latest_training_job.name
print(f"Training job: {training_job_name}")

# Console link to the log group shared by all SageMaker training jobs
print("\nCloudWatch logs:")
print(f"https://console.aws.amazon.com/cloudwatch/home?region={region}#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs")

Training job: pytorch-training-2026-01-25-20-44-53-697

CloudWatch logs:
https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252FTrainingJobs
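That link opens the whole `/aws/sagemaker/TrainingJobs` log group. Each job writes streams named `<job-name>/algo-<n>-<timestamp>`, so a stream-name filter narrows the view to this job. The following sketch builds such a URL with a hypothetical helper, `cloudwatch_job_logs_url`. The `$252F` and `$3F...$3D` escapes mirror the encoding the console uses in its own URLs. That is an observed convention, not a documented API, so check the link in your browser.

```python
from urllib.parse import quote


def _console_escape(text: str) -> str:
    # The CloudWatch console writes percent-escapes with "$" instead of "%"
    return quote(text, safe="").replace("%", "$")


def cloudwatch_job_logs_url(region: str, job_name: str,
                            log_group: str = "/aws/sagemaker/TrainingJobs") -> str:
    """Console URL for the training-job log group, filtered to one job's streams."""
    group = _console_escape(quote(log_group, safe=""))  # "/" becomes "$252F"
    query = _console_escape(f"?logStreamNameFilter={job_name}")  # "?" -> "$3F", "=" -> "$3D"
    return (f"https://console.aws.amazon.com/cloudwatch/home?region={region}"
            f"#logsV2:log-groups/log-group/{group}{query}")
```

In the notebook this would be called as `cloudwatch_job_logs_url(region, training_job_name)`.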
