To use the PyTorch-specific APIs for SageMaker distributed model parallelism, you need to add the following import statement at the top of your training script.
import smdistributed.modelparallel.torch as smp
Tip
Refer to Modify a PyTorch Training Script to learn how to use the following APIs in your PyTorch training script.
Behavior of smp.DistributedModel with Tensor Parallelism

When a model is wrapped by smp.DistributedModel, the library immediately traverses the modules of the model object and replaces the modules that are supported for tensor parallelism with their distributed counterparts. This replacement happens in place. If there are no other references to the original modules in the script, they are garbage-collected. The module attributes that previously referred to the original submodules now refer to the distributed versions of those submodules.

Example:
# register DistributedSubmodule as the distributed version of Submodule
# (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist)
smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule)

class MyModule(nn.Module):
    def __init__(self):
        ...
        self.submodule = Submodule()
        ...

# enabling tensor parallelism for the entire model
with smp.tensor_parallelism():
    model = MyModule()

# here model.submodule is still a Submodule object
assert isinstance(model.submodule, Submodule)

model = smp.DistributedModel(model)

# now model.submodule is replaced with an equivalent instance
# of smp.nn.DistributedSubmodule
assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule)
If pipeline_parallel_degree (equivalently, partitions) is 1, the placement of model partitions onto GPUs and the initial broadcast of model parameters and buffers across data-parallel ranks take place immediately, because there is no need to wait for model partitioning when the smp.DistributedModel wrapper is called. For other cases, with pipeline_parallel_degree greater than 1, the broadcast and device placement are deferred until the first call of an smp.step-decorated function, because that first call is when model partitioning happens if pipeline parallelism is enabled.
Because of the module replacement during the smp.DistributedModel call, any load_state_dict calls on the model, as well as any direct access to model parameters, such as during optimizer creation, should be done after the smp.DistributedModel call.
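For illustration, the following is a minimal sketch of this ordering (MyModel is a hypothetical model class; the optimizer is a standard torch.optim optimizer): the model is wrapped first, and only then is the optimizer created and any saved state loaded.

import torch
import smdistributed.modelparallel.torch as smp

model = MyModel()
model = smp.DistributedModel(model)

# Create the optimizer only after wrapping, so that it references the
# parameters of the distributed (replaced) submodules.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = smp.DistributedOptimizer(optimizer)

# Likewise, call load_state_dict only after wrapping.
# model.load_state_dict(checkpoint["model_state_dict"])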
Since the broadcast of the model parameters and buffers happens immediately during the smp.DistributedModel call when the degree of pipeline parallelism is 1, using @smp.step decorators is not required when tensor parallelism is used by itself (without pipeline parallelism).

For more information about the library's tensor parallelism APIs for PyTorch, see smdmp-pytorch-tensor-parallel.
Additional Methods of smp.DistributedModel for Tensor Parallelism

The following are the new methods of smp.DistributedModel, in addition to the ones listed in the documentation.

- distributed_modules() - An iterator that runs over the set of distributed (tensor-parallelized) modules in the model.
- is_distributed_parameter(param) - Returns True if the given nn.Parameter is distributed over tensor-parallel ranks.
- is_distributed_buffer(buf) - Returns True if the given buffer is distributed over tensor-parallel ranks.
- is_scaled_batch_parameter(param) - Returns True if the given nn.Parameter operates on the scaled batch (the batch over the entire TP_GROUP, not only the local batch).
- is_scaled_batch_buffer(buf) - Returns True if the parameter corresponding to the given buffer operates on the scaled batch (the batch over the entire TP_GROUP, not only the local batch).
- default_reducer_named_parameters() - Returns an iterator that runs over (name, param) tuples, for param that is allreduced over the DP_GROUP.
- scaled_batch_reducer_named_parameters() - Returns an iterator that runs over (name, param) tuples, for param that is allreduced over the RDP_GROUP.
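As a brief illustration, the following sketch (assuming model is an smp.DistributedModel instance) uses two of these methods to inspect which parts of the model are tensor-parallelized:

# Iterate over the tensor-parallelized (distributed) modules.
for dist_module in model.distributed_modules():
    print(type(dist_module).__name__)

# Check which parameters are split across tensor-parallel ranks.
for name, param in model.named_parameters():
    if model.is_distributed_parameter(param):
        print(name, "is distributed over the TP_GROUP")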
smp.DistributedOptimizer

Parameters: optimizer - the torch optimizer to be wrapped.

An optimizer wrapper for saving and loading optimizer states. This wrapper returns the optimizer with the following methods overridden:
state_dict( )

Returns the state_dict that contains the optimizer state for the entire model. It first collects the local_state_dict and gathers and merges the local_state_dict from all mp_ranks to create a full state_dict.
load_state_dict( )

Same as torch.optim.Optimizer.load_state_dict(), except:

- It first gathers and merges the local state_dicts if they are partial.
- The actual loading happens after the model partition so that each rank knows its local parameters.
local_state_dict( )

Returns the state_dict that contains the local optimizer state that belongs to the current mp_rank. This state_dict contains a key _smp_is_partial that indicates whether the state_dict contains elements corresponding to only the current partition or to the entire model.
smp.partition(index)

Inputs: index (int) - The index of the partition.

A context manager which places all modules defined inside into the partition with ID index. The index argument must be less than the number of partitions.

Use smp.partition to implement manual partitioning. If "auto_partition" is True, then the smp.partition contexts are ignored. Any module that is not placed in any smp.partition context is placed in the default_partition defined through the SageMaker Python SDK.

When smp.partition contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module __init__, and the partition assignment applies to the modules that are created inside the smp.partition context.
Example:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        with smp.partition(1):
            self.child0 = Child0()            # child0 on partition 1
            with smp.partition(2):
                self.child1 = Child1()        # child1 on partition 2
            self.child2 = Child2()            # child2 on partition 1
        self.child3 = Child3()                # child3 on default_partition
smp.get_world_process_group( )

Returns a torch.distributed ProcessGroup that consists of all processes, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.

smp.get_mp_process_group( )

Returns a torch.distributed ProcessGroup that consists of the processes in the MP_GROUP which contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.

smp.get_dp_process_group( )

Returns a torch.distributed ProcessGroup that consists of the processes in the DP_GROUP which contains the current process, which can be used with the torch.distributed API. Requires "ddp": True in SageMaker Python SDK parameters.
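As a hedged sketch, the following uses the process group returned by smp.get_dp_process_group() with the standard torch.distributed API (this assumes "ddp": True is set and smp.init() has been called):

import torch
import torch.distributed as dist
import smdistributed.modelparallel.torch as smp

# All-reduce a tensor across the data-parallel group only.
dp_group = smp.get_dp_process_group()
t = torch.ones(1).to(torch.device("cuda", smp.local_rank()))
dist.all_reduce(t, op=dist.ReduceOp.SUM, group=dp_group)
# t now holds the number of ranks in the DP_GROUP.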
smp.is_initialized( )

Returns True if smp.init has already been called for the process, and False otherwise.

smp.is_tracing( )

Returns True if the current process is running the tracing step, and False otherwise.
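For example, a training script can guard library-specific logic so that it also runs outside of a SageMaker model parallel job (a small sketch; smp.rank() is part of the library's core API):

if smp.is_initialized():
    # Only log from a single process.
    if smp.rank() == 0:
        print("SageMaker model parallelism is initialized")
else:
    print("running without smp.init()")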
smp.nn.FusedLayerNorm

Apex Fused Layer Norm is currently not supported by the library. smp.nn.FusedLayerNorm replaces apex FusedLayerNorm and provides the same functionality. This requires apex to be installed on the system.
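A minimal sketch, assuming the constructor takes the same arguments as the apex FusedLayerNorm (normalized_shape and eps):

# Drop-in replacement for apex.normalization.FusedLayerNorm;
# requires apex to be installed on the system.
layer_norm = smp.nn.FusedLayerNorm(1024, eps=1e-5)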
smp.optimizers.FusedNovoGrad

Fused Novo Grad optimizer is currently not supported by the library. smp.optimizers.FusedNovoGrad replaces apex FusedNovoGrad optimizer and provides the same functionality. This requires apex to be installed on the system.
smp.optimizers.FusedLamb

FusedLamb optimizer currently doesn't work with the library. smp.optimizers.FusedLamb replaces apex FusedLamb optimizer and provides the same functionality. This requires apex to be installed on the system.
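A minimal sketch, assuming the constructor takes the same arguments as the apex FusedLamb optimizer (params and lr); the same pattern applies to smp.optimizers.FusedNovoGrad:

# Drop-in replacement for apex.optimizers.FusedLamb; requires apex.
optimizer = smp.optimizers.FusedLamb(model.parameters(), lr=1e-3)
optimizer = smp.DistributedOptimizer(optimizer)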
smp.amp.GradScaler

Torch AMP GradScaler currently doesn't work with the library. smp.amp.GradScaler replaces torch.amp.GradScaler and provides the same functionality.
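The following is a hedged sketch of an AMP training step that uses the replacement scaler together with an smp.step-decorated function (the model, optimizer, and data loader are assumed to be set up as in the earlier examples):

import torch
import torch.nn.functional as F

scaler = smp.amp.GradScaler()

@smp.step
def train_step(model, data, target):
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = F.nll_loss(output, target)
    # Scale the loss before the model-parallel backward pass.
    model.backward(scaler.scale(loss))
    return loss

for data, target in train_loader:
    optimizer.zero_grad()
    loss_mb = train_step(model, data, target)   # per-microbatch loss (StepOutput)
    scaler.step(optimizer)
    scaler.update()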
smp.save( )

Saves an object. This operation is similar to torch.save(), except it has an additional keyword argument, partial, and accepts only string type for the argument f (file). If partial=True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to your saved file.

Parameters

- obj (dict): The object to be saved.
- f (str): A string containing a file name.
- partial (bool, default=True): When set to True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to the saved file. If you want to be able to load and further train a model that you save with smp.save(), you must set partial=True.
- pickle_module (picklemodule, default = module "pickle" from "/opt/conda/lib/python3.6/pickle.py"): A module used for pickling metadata and objects.
- pickle_protocol (int, default=2): Can be specified to override the default protocol.
smp.load( )

Loads an object saved with smp.save() from a file.

Similar to torch.load(), except it has an additional keyword argument, partial, and accepts only string type for the argument f (file). If partial=True, then each mp_rank loads a separate checkpoint file.

Parameters

- f (str): A string containing a file name.
- map_location (function, torch.device, string, or dict): Specifies how to remap storage locations.
- pickle_module (pickle module): A module used for unpickling metadata and objects (has to match the pickle_module used to serialize the file).
- pickle_load_args (Python 3 only): Optional keyword arguments passed to pickle_module.load() and pickle_module.Unpickler().
- partial (bool, default=True): When set to True, each mp_rank loads the checkpoint corresponding to the mp_rank. Should be used when loading a model trained with the library.
The library can save partial or full checkpoints.

- For partial checkpoints, each mp_rank saves its own checkpoint file with only the parameters that belong to that rank.
- For full checkpoints, the library saves a single checkpoint that contains the entire model parameters.

When saving using smp.save(), each rank only holds its own parameters. If you want to save the full model, there will be some communication between the ranks to create the full model. If you save checkpoints often, you should save partial checkpoints for best performance.

When loading using smp.load(), the library can load either partial checkpoints, full checkpoints, or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or do inference, you need a full checkpoint.
The following is an example of how you can save and load a checkpoint (the save_partial_model, save_full_model, partial_checkpoint, and full_checkpoint flags in this example are user-defined):
# Original model and optimizer
model = MyModel(...)
optimizer = MyOpt(...)

# model parallel wrapper
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

# To save, always save on dp_rank 0 to avoid data racing

# To save the partial model on each mp rank
# the library will create `checkpoint.pt_{mprank}` for each mp rank
if save_partial_model:
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()    # save the partial model
        opt_dict = optimizer.local_state_dict()  # save the partial optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=True,
        )

# To save the full model
if save_full_model:
    if smp.dp_rank() == 0:
        model_dict = model.state_dict()    # save the full model
        opt_dict = optimizer.state_dict()  # save the full optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=False,
        )

# To load, load on all ranks.
# The only difference for partial/full loading is the partial flag in smp.load

# Load partial checkpoint
if partial_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Load full checkpoint
if full_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])