PyTorch API

To use the PyTorch-specific APIs for SageMaker distributed model parallelism, add the following import statement at the top of your training script.

import smdistributed.modelparallel.torch as smp

Tip

Refer to Modify a PyTorch Training Script to learn how to use the following APIs in your PyTorch training script.

Behavior of smp.DistributedModel with Tensor Parallelism

When a model is wrapped by smp.DistributedModel, the library immediately traverses the modules of the model object, and replaces the modules that are supported for tensor parallelism with their distributed counterparts. This replacement happens in place. If there are no other references to the original modules in the script, they are garbage-collected. The module attributes that previously referred to the original submodules now refer to the distributed versions of those submodules.

Example:

import torch.nn as nn
import smdistributed.modelparallel.torch as smp

# register DistributedSubmodule as the distributed version of Submodule
# (note this is a hypothetical example: Submodule stands for one of your own
# module classes, and smp.nn.DistributedSubmodule does not exist)
smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        ...

        self.submodule = Submodule()
    ...

# enabling tensor parallelism for the entire model
with smp.tensor_parallelism():
    model = MyModule()

# here model.submodule is still a Submodule object
assert isinstance(model.submodule, Submodule)

model = smp.DistributedModel(model)

# now model.submodule is replaced with an equivalent instance
# of smp.nn.DistributedSubmodule
assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule)

If pipeline_parallel_degree (equivalently, partitions) is 1, the placement of model partitions into GPUs and the initial broadcast of model parameters and buffers across data-parallel ranks take place immediately when smp.DistributedModel is called, because there is no model partitioning to wait for. When pipeline_parallel_degree is greater than 1, the broadcast and device placement are deferred until the first call of an smp.step-decorated function, because that first call is when the model is partitioned if pipeline parallelism is enabled.

Because of the module replacement during the smp.DistributedModel call, any load_state_dict calls on the model, as well as any direct access to model parameters, such as during the optimizer creation, should be done after the smp.DistributedModel call.
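The in-place replacement, and why references taken before the wrap go stale, can be illustrated with a pure-Python sketch. It imports neither smdistributed nor PyTorch: Module, Submodule, DistributedSubmodule, REGISTRY, register, and replace_in_place are hypothetical stand-ins, and the "sharding" simply keeps one slice of a list per tensor-parallel rank. It is not the library's implementation.

```python
class Module:
    """Minimal stand-in for torch.nn.Module; submodules are plain attributes."""

    def named_children(self):
        # Snapshot as a list so attributes can be replaced while iterating.
        return [(k, v) for k, v in vars(self).items() if isinstance(v, Module)]


class Submodule(Module):
    def __init__(self):
        self.weight = [1.0, 2.0, 3.0, 4.0]  # hypothetical full parameter


class DistributedSubmodule(Module):
    def __init__(self, original, tp_rank, tp_degree):
        # Hypothetical sharding: keep only this rank's contiguous slice.
        shard = len(original.weight) // tp_degree
        self.weight = original.weight[tp_rank * shard:(tp_rank + 1) * shard]


REGISTRY = {}  # module class -> distributed counterpart


def register(module_cls, dist_cls):
    REGISTRY[module_cls] = dist_cls


def replace_in_place(module, tp_rank, tp_degree):
    """Recursively swap registered submodules for their distributed versions."""
    for name, child in module.named_children():
        dist_cls = REGISTRY.get(type(child))
        if dist_cls is not None:
            setattr(module, name, dist_cls(child, tp_rank, tp_degree))
        else:
            replace_in_place(child, tp_rank, tp_degree)
    return module


class MyModule(Module):
    def __init__(self):
        self.submodule = Submodule()


register(Submodule, DistributedSubmodule)
model = MyModule()
params_before_wrap = model.submodule.weight  # e.g. captured by an optimizer
replace_in_place(model, tp_rank=1, tp_degree=2)

print(type(model.submodule).__name__)  # DistributedSubmodule
print(model.submodule.weight)          # [3.0, 4.0]
print(params_before_wrap)              # [1.0, 2.0, 3.0, 4.0] (stale)
```

An optimizer constructed before the swap would hold params_before_wrap, which no longer belongs to the model; that is the situation the ordering rule above avoids.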

Because the broadcast of the model parameters and buffers happens immediately during the smp.DistributedModel call when the degree of pipeline parallelism is 1, the @smp.step decorator is not required when tensor parallelism is used by itself (without pipeline parallelism).
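The timing difference between pipeline_parallel_degree equal to 1 and greater than 1 can be modeled with a small hypothetical class that only records events. DistributedModelSketch and its step method are illustrative stand-ins for smp.DistributedModel and an smp.step-decorated function; they perform no placement or communication.

```python
class DistributedModelSketch:
    """Hypothetical stand-in that records when placement and broadcast happen."""

    def __init__(self, pipeline_parallel_degree):
        self.pipeline_parallel_degree = pipeline_parallel_degree
        self.events = ["wrap"]
        self._ready = False
        if pipeline_parallel_degree == 1:
            # No partitioning to wait for: place and broadcast right away.
            self._place_and_broadcast()

    def _place_and_broadcast(self):
        self.events.append("place+broadcast")
        self._ready = True

    def step(self):
        """Stands in for a call of an smp.step-decorated function."""
        if not self._ready:
            # With pipeline parallelism, the first step partitions the model,
            # and only then can placement and broadcast happen.
            self.events.append("partition")
            self._place_and_broadcast()
        self.events.append("step")


tp_only = DistributedModelSketch(pipeline_parallel_degree=1)
tp_only.step()
with_pp = DistributedModelSketch(pipeline_parallel_degree=2)
with_pp.step()
with_pp.step()

print(tp_only.events)  # ['wrap', 'place+broadcast', 'step']
print(with_pp.events)  # ['wrap', 'partition', 'place+broadcast', 'step', 'step']
```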

For more information about the library's tensor parallelism APIs for PyTorch, see smdmp-pytorch-tensor-parallel.

Additional Methods of smp.DistributedModel for Tensor Parallelism

The following are the new methods of smp.DistributedModel, in addition to the ones listed in the documentation.

distributed_modules()

  • Returns an iterator that runs over the set of distributed (tensor-parallelized) modules in the model.

is_distributed_parameter(param)

  • Returns True if the given nn.Parameter is distributed over tensor-parallel ranks.

is_distributed_buffer(buf)

  • Returns True if the given buffer is distributed over tensor-parallel ranks.

is_scaled_batch_parameter(param)

  • Returns True if the given nn.Parameter operates on the scaled batch (the batch over the entire TP_GROUP, not only the local batch).

is_scaled_batch_buffer(buf)

  • Returns True if the parameter corresponding to the given buffer operates on the scaled batch (batch over the entire TP_GROUP, and not only the local batch).

default_reducer_named_parameters()

  • Returns an iterator that runs over (name, param) tuples, for each param that is allreduced over the DP_GROUP.

scaled_batch_reducer_named_parameters()

  • Returns an iterator that runs over (name, param) tuples, for each param that is allreduced over the RDP_GROUP.
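One way to picture the two reducer iterators is as a split of the model's named parameters. The sketch below assumes that each parameter goes to exactly one reducer and that the split is decided by whether the parameter operates on the scaled batch; that rule, the function split_for_reducers, and the string placeholders for parameters are assumptions for illustration, not documented library behavior.

```python
def split_for_reducers(named_parameters, is_scaled_batch_parameter):
    """Partition (name, param) pairs between two hypothetical reducers.

    Assumption: scaled-batch parameters are allreduced over RDP_GROUP and all
    others over DP_GROUP, mirroring scaled_batch_reducer_named_parameters()
    and default_reducer_named_parameters().
    """
    default, scaled_batch = [], []
    for name, param in named_parameters:
        target = scaled_batch if is_scaled_batch_parameter(param) else default
        target.append((name, param))
    return default, scaled_batch


# Hypothetical parameters: strings stand in for nn.Parameter objects.
params = [("embed.weight", "p0"), ("attn.qkv.weight", "p1"), ("ln.bias", "p2")]
scaled = {"p1"}  # pretend only p1 operates on the scaled batch

default, scaled_batch = split_for_reducers(params, lambda p: p in scaled)
print([n for n, _ in default])       # ['embed.weight', 'ln.bias']
print([n for n, _ in scaled_batch])  # ['attn.qkv.weight']
```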

smp.DistributedOptimizer(optimizer)

Parameters

  • optimizer: The optimizer to wrap.

An optimizer wrapper for saving and loading optimizer states. This wrapper returns optimizer with the following methods overridden:

state_dict( )

Returns the state_dict that contains the optimizer state for the entire model. It first collects the local_state_dict on each mp_rank, then gathers and merges the local state_dicts from all mp_ranks to create the full state_dict.

load_state_dict( )

Same as torch.optim.Optimizer.load_state_dict(), except that:

  • It first gathers and merges the local state_dicts if they are partial.
  • The actual loading happens after the model partition so that each rank knows its local parameters.

local_state_dict( )

Returns the state_dict that contains the local optimizer state that belongs to the current mp_rank. This state_dict contains the key _smp_is_partial, which indicates whether the state_dict contains elements corresponding only to the current partition (a partial state_dict) or to the entire model.
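The gather-and-merge step behind state_dict(), and the role of _smp_is_partial, can be pictured with plain dictionaries. The layout below is hypothetical: it keys optimizer state by parameter name for readability, whereas torch.optim state dicts key state by integer parameter index and also carry param_groups. The function merge_partial_state_dicts is illustrative; the library performs the real merge across ranks.

```python
def merge_partial_state_dicts(local_state_dicts):
    """Combine per-mp_rank partial optimizer states into one full state dict."""
    merged = {"state": {}, "_smp_is_partial": False}
    for sd in local_state_dicts:
        if not sd.get("_smp_is_partial", False):
            raise ValueError("expected partial state_dicts only")
        overlap = merged["state"].keys() & sd["state"].keys()
        if overlap:
            raise ValueError(f"state owned by more than one rank: {sorted(overlap)}")
        merged["state"].update(sd["state"])
    return merged


# Hypothetical local_state_dict() results from two mp_ranks.
rank0 = {"state": {"layer1.weight": {"step": 10}}, "_smp_is_partial": True}
rank1 = {"state": {"layer2.weight": {"step": 10}}, "_smp_is_partial": True}

full = merge_partial_state_dicts([rank0, rank1])
print(sorted(full["state"]))    # ['layer1.weight', 'layer2.weight']
print(full["_smp_is_partial"])  # False
```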

smp.partition(index)

Inputs

  • index (int) - The index of the partition.

A context manager that places all modules defined inside it into the partition with ID index. The index argument must be less than the number of partitions.

Use smp.partition to implement manual partitioning. If "auto_partition" is True, then the smp.partition contexts are ignored. Any module that is not placed in any smp.partition context is placed in the default_partition defined through the SageMaker Python SDK.

When smp.partition contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module __init__, and the partition assignment applies to the modules that are created inside the smp.partition context.

Example:

import torch
import smdistributed.modelparallel.torch as smp

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        with smp.partition(1):
            self.child0 = Child0()            # child0 on partition 1
            with smp.partition(2):
                self.child1 = Child1()        # child1 on partition 2
            self.child2 = Child2()            # child2 on partition 1
        self.child3 = Child3()                # child3 on default_partition

smp.get_world_process_group( )

Returns a torch.distributed ProcessGroup that consists of all processes and can be used with the torch.distributed API. Requires "ddp": True in the SageMaker Python SDK parameters.

smp.get_mp_process_group( )

Returns a torch.distributed ProcessGroup that consists of the processes in the MP_GROUP containing the current process and can be used with the torch.distributed API. Requires "ddp": True in the SageMaker Python SDK parameters.

smp.get_dp_process_group( )

Returns a torch.distributed ProcessGroup that consists of the processes in the DP_GROUP containing the current process and can be used with the torch.distributed API. Requires "ddp": True in the SageMaker Python SDK parameters.
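To make the group definitions concrete, the sketch below enumerates the MP_GROUP and DP_GROUP that contain a given rank. The rank layout is an assumption made for illustration; the library decides the actual layout, and with smp you would call the functions above instead of building groups yourself. In plain torch.distributed, each such rank list could be passed to torch.distributed.new_group to create a ProcessGroup.

```python
def build_groups(world_size, mp_degree):
    """Return (mp_groups, dp_groups) as lists of global ranks.

    Assumed layout (hypothetical): each block of mp_degree consecutive ranks
    forms one MP_GROUP (one model replica), and ranks at the same position
    across blocks form one DP_GROUP.
    """
    if world_size % mp_degree:
        raise ValueError("world_size must be divisible by mp_degree")
    mp_groups = [list(range(s, s + mp_degree)) for s in range(0, world_size, mp_degree)]
    dp_groups = [list(range(i, world_size, mp_degree)) for i in range(mp_degree)]
    return mp_groups, dp_groups


def group_of(rank, groups):
    """Return the group that contains the given rank."""
    return next(g for g in groups if rank in g)


mp_groups, dp_groups = build_groups(world_size=8, mp_degree=2)
print(group_of(5, mp_groups))  # [4, 5]
print(group_of(5, dp_groups))  # [1, 3, 5, 7]
```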

smp.is_initialized( )

Returns True if smp.init has already been called for the process, and False otherwise.

smp.is_tracing( )

Returns True if the current process is running the tracing step, and False otherwise.

smp.nn.FusedLayerNorm

Apex FusedLayerNorm is currently not supported by the library. smp.nn.FusedLayerNorm replaces Apex FusedLayerNorm and provides the same functionality. This requires Apex to be installed on the system.

smp.optimizers.FusedNovoGrad

The Apex FusedNovoGrad optimizer is currently not supported by the library. smp.optimizers.FusedNovoGrad replaces the Apex FusedNovoGrad optimizer and provides the same functionality. This requires Apex to be installed on the system.

smp.optimizers.FusedLamb

The Apex FusedLamb optimizer currently doesn’t work with the library. smp.optimizers.FusedLamb replaces the Apex FusedLamb optimizer and provides the same functionality. This requires Apex to be installed on the system.

smp.amp.GradScaler

The PyTorch AMP GradScaler currently doesn’t work with the library. smp.amp.GradScaler replaces torch.amp.GradScaler and provides the same functionality.

APIs for Saving and Loading

smp.save( )

Saves an object. This operation is similar to torch.save(), except that it has an additional keyword argument, partial, and accepts only a string for the argument f (file). If partial=True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to the saved file name.

Parameters

  • obj (dict): The object to save.
  • f (str): A string containing a file name.
  • partial (bool, default=True): When set to True, each mp_rank saves a separate checkpoint file and the library adds an mp_rank index to the saved file name. If you want to be able to load and further train a model that you save with smp.save(), you must set partial=True.
  • pickle_module (pickle module, default=pickle): A module used for pickling metadata and objects.
  • pickle_protocol (int, default=2): Can be specified to override the default protocol.

smp.load( )

Loads an object saved with smp.save() from a file.

Similar to torch.load(), except that it has an additional keyword argument, partial, and accepts only a string for the argument f (file). If partial=True, each mp_rank loads a separate checkpoint file.

Parameters

  • f (str): A string containing a file name.
  • map_location (function, torch.device, str, or dict): A function, torch.device, string, or dict specifying how to remap storage locations.
  • pickle_module (pickle module): A module used for unpickling metadata and objects (must match the pickle_module used to serialize the file).
  • pickle_load_args (Python 3 only): Optional keyword arguments passed to pickle_module.load() and pickle_module.Unpickler().
  • partial (bool, default=True): When set to True, each mp_rank loads the checkpoint corresponding to its mp_rank. Use this when loading a model trained with the library.
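The effect of the partial flag on file names can be simulated with the standard library alone. save_sketch and load_sketch are hypothetical helpers, not smp.save and smp.load: they take mp_rank as an explicit argument (the library already knows it), use pickle in place of torch serialization, and follow the naming pattern from the save/load example in this document (checkpoint.pt followed by an underscore and the mp_rank).

```python
import os
import pickle
import tempfile


def _checkpoint_path(f, mp_rank, partial):
    # Partial checkpoints get the mp_rank appended: checkpoint.pt_0, checkpoint.pt_1, ...
    return f"{f}_{mp_rank}" if partial else f


def save_sketch(obj, f, mp_rank, partial=True, pickle_protocol=2):
    with open(_checkpoint_path(f, mp_rank, partial), "wb") as fh:
        pickle.dump(obj, fh, protocol=pickle_protocol)


def load_sketch(f, mp_rank, partial=True):
    with open(_checkpoint_path(f, mp_rank, partial), "rb") as fh:
        return pickle.load(fh)


tmpdir = tempfile.mkdtemp()
ckpt = os.path.join(tmpdir, "checkpoint.pt")

# Each simulated mp_rank writes only the parameters it owns.
local_states = [{"layer1.weight": [1.0]}, {"layer2.weight": [2.0]}]
for mp_rank, state in enumerate(local_states):
    save_sketch({"model_state_dict": state}, ckpt, mp_rank, partial=True)

print(sorted(os.listdir(tmpdir)))                        # ['checkpoint.pt_0', 'checkpoint.pt_1']
print(load_sketch(ckpt, mp_rank=1)["model_state_dict"])  # {'layer2.weight': [2.0]}
```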

General Instruction For Saving and Loading

The library can save partial or full checkpoints.

  • For partial checkpoints, each mp_rank saves its own checkpoint file with only the parameters that belong to that rank.
  • For full checkpoints, the library saves a single checkpoint that contains entire model parameters.

When you save with smp.save(), each rank holds only its own parameters, so saving the full model requires communication between the ranks to assemble it. If you save checkpoints often, save partial checkpoints for best performance.

When loading with smp.load(), the library can load partial checkpoints, full checkpoints, or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or run inference, you need a full checkpoint.

The following is an example of how you can save and load a checkpoint:

import smdistributed.modelparallel.torch as smp

# Original model and optimizer (MyModel and MyOpt are placeholders)
model = MyModel(...)
optimizer = MyOpt(...)

# model parallel wrapper
model = smp.DistributedModel(model)
optimizer = smp.DistributedOptimizer(optimizer)

# To save, always save on dp_rank 0 to avoid data races between
# data-parallel replicas writing the same files.
# save_partial_model and save_full_model are placeholder flags.
if save_partial_model:
    # To save the partial model on each mp rank,
    # the library creates `checkpoint.pt_{mp_rank}` for each mp rank
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict() # save the partial model
        opt_dict = optimizer.local_state_dict() # save the partial optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=True,
        )

# To save the full model
if save_full_model:
    if smp.dp_rank() == 0:
        model_dict = model.state_dict() # save the full model
        opt_dict = optimizer.state_dict() # save the full optimizer state
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            "/checkpoint.pt",
            partial=False,
        )

# To load, load on all ranks.
# The only difference for partial/full loading is the partial flag in smp.load.
# partial_checkpoint and full_checkpoint are placeholder flags.
# Load partial checkpoint
if partial_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Load full checkpoint
if full_checkpoint:
    checkpoint = smp.load("/checkpoint.pt", partial=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])