Describe the bug
SageMaker's PyTorch DistributedDataParallel (smdistributed.dataparallel) does not support model.no_sync().
I can't find how the DistributedDataParallel wrapper is implemented, but it seems the no_sync() context manager is not provided.
To reproduce
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP

# assume `model` is a PyTorch model (torch.nn.Module)
model = DDP(model)

with model.no_sync():  # this line causes the error
    pass
Expected behavior
A DDP-wrapped model should be able to call no_sync(). Under gradient accumulation, skipping the gradient allreduce on intermediate steps can save a substantial amount of compute (~20% per huggingface/transformers#7742).
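For reference, the native torch.nn.parallel.DistributedDataParallel already exposes no_sync(), and the snippet below is a minimal sketch of how it is typically combined with gradient accumulation (the model, optimizer, data_loader, and accumulation_steps names are placeholders for illustration, not part of the smdistributed API):

import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

def train_epoch(model: DistributedDataParallel, optimizer, data_loader, accumulation_steps):
    model.train()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(data_loader):
        is_update_step = (step + 1) % accumulation_steps == 0
        if is_update_step:
            # Last micro-batch in the window: backward() triggers the gradient allreduce.
            loss = F.cross_entropy(model(inputs), targets)
            (loss / accumulation_steps).backward()
            optimizer.step()
            optimizer.zero_grad()
        else:
            # Intermediate micro-batches: no_sync() skips the allreduce, so
            # gradients just accumulate locally on each worker.
            with model.no_sync():
                loss = F.cross_entropy(model(inputs), targets)
                (loss / accumulation_steps).backward()

The expectation is that the smdistributed DDP wrapper would support the same pattern.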
I see you're pushing out smdistributed 1.1.0; I'm not sure whether no_sync() is included in that release.
Screenshots or logs
System information
- SageMaker Python SDK version: 2.24.5
- Framework name (e.g. PyTorch) or algorithm (e.g. KMeans): PyTorch
- Framework version: 1.7.1
- Python version: 3.6
- CPU or GPU: GPU
- Custom Docker image (Y/N): N
Additional context
I'm using this Docker image: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:1.7.1-gpu-py36-cu110-ubuntu18.04