From 1e80bbb5d33a18bf0665ee6d540d3b4097f5d34b Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 7 Aug 2025 00:16:27 +0200 Subject: [PATCH 01/14] feat: implemented stage FQN generation for pipeline parallelism --- src/modalities/models/parallelism/__init__.py | 0 .../parallelism/pipeline_parallelism.py | 88 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 src/modalities/models/parallelism/__init__.py create mode 100644 src/modalities/models/parallelism/pipeline_parallelism.py diff --git a/src/modalities/models/parallelism/__init__.py b/src/modalities/models/parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py new file mode 100644 index 000000000..e1d2233ba --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -0,0 +1,88 @@ +# Some portions of this implementation are inspired and/or adapted +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import math +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class + +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees + + +class FQNsPerStageGenerator(ABC): + @abstractmethod + def generate_fqns_per_stage( + self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 + ) -> list[list[str]]: + """ + Generate a list of fully qualified names (FQNs) for each pipeline stage. + + Args: + num_stages (int): Number of stages in the pipeline. + num_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Determines to how many transformer layers + the input layer corresponds. Default is 1. + output_layer_equivalence (int): Determines to how many transformer layers + the output layer corresponds. Default is 1. + + Returns: + list[list[str]]: A list containing an FQN list for each stage. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class PipelineFactory: + """Pipeline factory class to create pipelined models.""" + + @staticmethod + def create_pipeline_model( + num_layers: int, + fqns_per_stage_generator: FQNsPerStageGenerator, + device_mesh: DeviceMesh, + pp_schedule_name: str, + num_layers_per_stage: int, + input_layer_equivalence: Optional[int] = 1, + output_layer_equivalence: Optional[int] = 1, + ) -> torch.nn.Module: + device_mesh[ParallelismDegrees.PP.value] + pp_dims = device_mesh.size(ParallelismDegrees.PP.value) + schedule_class = get_schedule_class(pp_schedule_name) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + if not is_single_stage_schedule: + raise ValueError( + f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." + ) + + # calculate the number of stages + num_virtual_stages = math.ceil( + (num_layers + input_layer_equivalence + output_layer_equivalence) / num_layers_per_stage + ) + if num_virtual_stages % pp_dims != 0: + raise ValueError( + f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. 
" + f"For reference: {num_layers=} {input_layer_equivalence=} " + f"{output_layer_equivalence=} {num_layers_per_stage=}" + ) + + stages_per_rank = num_virtual_stages // pp_dims + if stages_per_rank != 1: + raise ValueError( + f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " + f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." + ) + + fqns_per_stage_generator.generate_fqns_per_stage( + num_stages=num_virtual_stages, + num_layers=num_layers, + input_layer_equivalence=input_layer_equivalence, + output_layer_equivalence=output_layer_equivalence, + ) + + @staticmethod + def create_gpt2_model_splitter(): + """Create a GPT-2 model splitter for pipeline parallelism.""" + pass From ed93d284d108a1008b0eb7dae102f3b4a98f4c46 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 7 Aug 2025 15:49:13 +0200 Subject: [PATCH 02/14] feat: added FQNs per stage calculation --- .../parallelism/pipeline_parallelism.py | 91 +++++++++++++++++-- 1 file changed, 83 insertions(+), 8 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index e1d2233ba..ac6be437c 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -1,4 +1,4 @@ -# Some portions of this implementation are inspired and/or adapted +# Some portions of this implementation are inspired, adapted, or refactored # from Meta's open-source project TorchTitan, # licensed under the BSD 3-Clause License. @@ -14,27 +14,102 @@ class FQNsPerStageGenerator(ABC): - @abstractmethod def generate_fqns_per_stage( self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 ) -> list[list[str]]: """ - Generate a list of fully qualified names (FQNs) for each pipeline stage. + Generate FQNs for each stage in a GPT-2 model. Args: num_stages (int): Number of stages in the pipeline. num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Determines to how many transformer layers - the input layer corresponds. Default is 1. - output_layer_equivalence (int): Determines to how many transformer layers - the output layer corresponds. Default is 1. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. Returns: - list[list[str]]: A list containing an FQN list for each stage. + list[list[str]]: A list containing FQNs for each stage. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. 
+ potential_split_points = self._get_potential_split_points( + num_layers=num_layers, + input_layer_equivalence=input_layer_equivalence, + output_layer_equivalence=output_layer_equivalence, + ) + # Calculate the weight per stage based on the total weight and number of stages + weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_stages) + # pack the stages with the layers + next_split_point = 0 + module_names_per_stage: list[list[str]] = [] + for _ in range(num_stages): + stage_fqns = [] + stage_weight = 0 + while next_split_point < len(potential_split_points): + fqns, weight = potential_split_points[next_split_point] + if weight > weight_per_stage: + raise ValueError( + f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " + "Please adjust the number of stages or the weight distribution." + ) + if stage_weight + weight > weight_per_stage: + break + stage_fqns.extend(fqns) + stage_weight += weight + next_split_point += 1 + module_names_per_stage.append(stage_fqns) + + return module_names_per_stage + + @abstractmethod + def _get_potential_split_points( + self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. """ raise NotImplementedError("This method should be implemented by subclasses.") +class GPT2LLMFQNsPerStageGenerator(FQNsPerStageGenerator): + def _get_potential_split_points( + self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. 
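+        # E.g., with num_layers=2 and input_layer_equivalence=2, the embedding entry
+        # below gets weight 2, while each transformer block and the output entry keep
+        # weight 1.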
+ potential_split_points = [ + (["transformer.wte", "transformer.wpe"], input_layer_equivalence), + *[([f"transformer.h.{i}"], 1) for i in range(num_layers)], + (["transformer.lm_head_norm", "transformer.lm_head"], output_layer_equivalence), + ] + + return potential_split_points + + class PipelineFactory: """Pipeline factory class to create pipelined models.""" From 6241ea8dbfef66ac7b89e07a0ee350ddb37306a5 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:31:34 +0200 Subject: [PATCH 03/14] feat: generic FQN-based PP staging --- .../parallelism/pipeline_parallelism.py | 279 +++++++++--------- 1 file changed, 145 insertions(+), 134 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index ac6be437c..0af1551af 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -2,162 +2,173 @@ # from Meta's open-source project TorchTitan, # licensed under the BSD 3-Clause License. -import math -from abc import ABC, abstractmethod -from typing import Optional +import copy +from typing import Any, Optional, Type import torch +import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class +from modalities.models.parallelism.stages_generator import StagesGenerator from modalities.running_env.fsdp.device_mesh import ParallelismDegrees -class FQNsPerStageGenerator(ABC): - def generate_fqns_per_stage( - self, num_stages: int, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 - ) -> list[list[str]]: - """ - Generate FQNs for each stage in a GPT-2 model. - - Args: - num_stages (int): Number of stages in the pipeline. - num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Number of layers corresponding to the input layer. - output_layer_equivalence (int): Number of layers corresponding to the output layer. - - Returns: - list[list[str]]: A list containing FQNs for each stage. - """ - - # Potential split points for GPT-2 model with each potential split point - # listing the FQNs of the modules in that stage and the computational weight. - # The computational weight of the input and output modules are estimated - # based on the number of layers they correspond to. - potential_split_points = self._get_potential_split_points( - num_layers=num_layers, - input_layer_equivalence=input_layer_equivalence, - output_layer_equivalence=output_layer_equivalence, - ) - # Calculate the weight per stage based on the total weight and number of stages - weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_stages) - # pack the stages with the layers - next_split_point = 0 - module_names_per_stage: list[list[str]] = [] - for _ in range(num_stages): - stage_fqns = [] - stage_weight = 0 - while next_split_point < len(potential_split_points): - fqns, weight = potential_split_points[next_split_point] - if weight > weight_per_stage: - raise ValueError( - f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " - "Please adjust the number of stages or the weight distribution." 
- ) - if stage_weight + weight > weight_per_stage: - break - stage_fqns.extend(fqns) - stage_weight += weight - next_split_point += 1 - module_names_per_stage.append(stage_fqns) - - return module_names_per_stage - - @abstractmethod - def _get_potential_split_points( - self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 - ) -> list[tuple[list[str], int]]: - """ - Returns a list of potential split points for the GPT-2 model. - - Args: - num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Number of layers corresponding to the input layer. - output_layer_equivalence (int): Number of layers corresponding to the output layer. - - Returns: - list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. - """ - raise NotImplementedError("This method should be implemented by subclasses.") - - -class GPT2LLMFQNsPerStageGenerator(FQNsPerStageGenerator): - def _get_potential_split_points( - self, num_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1 - ) -> list[tuple[list[str], int]]: - """ - Returns a list of potential split points for the GPT-2 model. - - Args: - num_layers (int): Total number of layers in the model. - input_layer_equivalence (int): Number of layers corresponding to the input layer. - output_layer_equivalence (int): Number of layers corresponding to the output layer. - - Returns: - list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. - """ - - # Potential split points for GPT-2 model with each potential split point - # listing the FQNs of the modules in that stage and the computational weight. - # The computational weight of the input and output modules are estimated - # based on the number of layers they correspond to. - potential_split_points = [ - (["transformer.wte", "transformer.wpe"], input_layer_equivalence), - *[([f"transformer.h.{i}"], 1) for i in range(num_layers)], - (["transformer.lm_head_norm", "transformer.lm_head"], output_layer_equivalence), - ] - - return potential_split_points - - class PipelineFactory: """Pipeline factory class to create pipelined models.""" @staticmethod - def create_pipeline_model( - num_layers: int, - fqns_per_stage_generator: FQNsPerStageGenerator, + def get_pipelined_model( + whole_model: nn.Module, + stages_generator: StagesGenerator, device_mesh: DeviceMesh, + local_rank: int, pp_schedule_name: str, num_layers_per_stage: int, - input_layer_equivalence: Optional[int] = 1, - output_layer_equivalence: Optional[int] = 1, ) -> torch.nn.Module: - device_mesh[ParallelismDegrees.PP.value] - pp_dims = device_mesh.size(ParallelismDegrees.PP.value) + device = torch.device("cuda", local_rank) + pp_dims = device_mesh[ParallelismDegrees.PP.value].size() + + fqns_per_stage = stages_generator.get_stages( + num_layers_per_stage=num_layers_per_stage, + pp_dims=pp_dims, + ) + + pp_mesh = device_mesh[ParallelismDegrees.PP.value] schedule_class = get_schedule_class(pp_schedule_name) is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) if not is_single_stage_schedule: raise ValueError( f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." 
) - - # calculate the number of stages - num_virtual_stages = math.ceil( - (num_layers + input_layer_equivalence + output_layer_equivalence) / num_layers_per_stage - ) - if num_virtual_stages % pp_dims != 0: - raise ValueError( - f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " - f"For reference: {num_layers=} {input_layer_equivalence=} " - f"{output_layer_equivalence=} {num_layers_per_stage=}" - ) - - stages_per_rank = num_virtual_stages // pp_dims - if stages_per_rank != 1: - raise ValueError( - f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " - f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." - ) - - fqns_per_stage_generator.generate_fqns_per_stage( - num_stages=num_virtual_stages, - num_layers=num_layers, - input_layer_equivalence=input_layer_equivalence, - output_layer_equivalence=output_layer_equivalence, + stage, model = PipelineFactory._get_split_model( + whole_model=whole_model, + schedule_class=schedule_class, + pp_mesh=pp_mesh, + device=device, + fqns_per_stage=fqns_per_stage, ) + return whole_model # TODO return pipelined model @staticmethod - def create_gpt2_model_splitter(): - """Create a GPT-2 model splitter for pipeline parallelism.""" - pass + def _get_split_model( + whole_model: nn.Module, + schedule_class: Type[PipelineScheduleSingle], + pp_mesh: DeviceMesh, + device: torch.device, + fqns_per_stage: list[list[str]], + ) -> tuple[PipelineStage, nn.Module]: + def get_stage_id_of_pp_rank(pp_mesh: DeviceMesh): + # NOTE: torch titan a more complicated way to get the stage id of pp rank + # since they also allow for multi-stage schedules + pp_rank = pp_mesh.get_local_rank() + return pp_rank + + @staticmethod + def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]: + fqn_tree = {} + fqns = set(fqns) # Ensure unique FQNs + for fqn in fqns: + parts = fqn.split(".") + current_level = fqn_tree + for part in parts[:-1]: + if part not in current_level: + current_level[part] = {} + elif len(current_level) == 0: + raise ValueError(f"Part {part} of {fqn} already exists " "in the tree as a leaf node.") + current_level = current_level[part] + if parts[-1] in current_level: + raise ValueError( + f" Leaf of {fqn} has already been defined in the tree as an intermediadate node or leaf! " + "Cannot replace the existing node as a leaf." 
+ ) + current_level[parts[-1]] = {} + + return fqn_tree + + def _build_stage_from_modules( + fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None + ) -> tuple[PipelineStage, nn.Module]: + if isinstance(module, nn.ModuleDict): + if module_name not in fqn_tree: + dict_modules = nn.ModuleDict({}) + else: + if len(fqn_tree) == 0: + # If the module is a leaf node, we can directly use it + dict_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + dict_modules = {} + dict_module_names = [name for name in module.keys() if name in fqn_tree[module_name]] + for dict_module_name in dict_module_names: + dict_modules[dict_module_name] = _build_stage_from_modules( + fqn_tree=fqn_tree[module_name], + module=module[dict_module_name], + module_name=dict_module_name, + ) + dict_modules = nn.ModuleDict(dict_modules) + # setattr(module, module_name, dict_modules) + return dict_modules + + elif isinstance(module, nn.ModuleList): + if module_name not in fqn_tree: + list_modules = nn.ModuleList([]) + else: + if len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + list_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + list_modules = [] + list_indices = [i for i in range(len(module)) if str(i) in fqn_tree[module_name]] + for idx in list_indices: + list_modules.append( + _build_stage_from_modules( + fqn_tree=fqn_tree[module_name], module=module[idx], module_name=str(idx) + ) + ) + list_modules = nn.ModuleList(list_modules) + # setattr(module, module_name, list_modules) + return list_modules + + else: # normal nn.Module + if module_name is not None and module_name not in fqn_tree: + # If the module is not in the FQN tree, set it to None + return None + elif module_name is not None and len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + return module + else: + # If the module is in the FQN tree, we need to build a staged module + # recursively from the FQN tree + for module_name, module_value in module.named_children(): + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + staged_module = _build_stage_from_modules( + fqn_tree=fqn_tree, module=module_value, module_name=module_name + ) + setattr(module, module_name, staged_module) + + return module + + if not issubclass(schedule_class, PipelineScheduleSingle): + raise NotImplementedError("Only single-stage schedules are supported for pipeline parallelism.") + + # NOTE: For multi-stage schedule, e.g., Interleaved 1F1B, we have multiple stages per pp rank. + # This would need to be adapted accordingly in this case. 
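+        # With a single-stage schedule each PP rank owns exactly one stage, so the
+        # stage index equals the PP rank. The whole model is deep-copied, pruned down
+        # to the FQNs assigned to this stage via the FQN tree, and wrapped in a
+        # PipelineStage.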
+ stage_idx = get_stage_id_of_pp_rank(pp_mesh) + module_names = fqns_per_stage[stage_idx] + whole_model = copy.deepcopy(whole_model) + fqn_tree = _get_fqn_tree(module_names) + stage_modules = _build_stage_from_modules(fqn_tree, whole_model) + stage = PipelineStage( + submodule=stage_modules, + stage_index=stage_idx, + num_stages=len(fqns_per_stage), + device=device, + group=pp_mesh.get_group("pp"), + ) + return stage, whole_model From 0ba8fbc4e0e8d20623d32844c90725a6beea3d09 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:32:08 +0200 Subject: [PATCH 04/14] feat: added PP configs --- .../pipeline_parallelism_configs.py | 22 +++++++++++++++++++ .../parallelism/stages_generator_configs.py | 13 +++++++++++ 2 files changed, 35 insertions(+) create mode 100644 src/modalities/models/parallelism/pipeline_parallelism_configs.py create mode 100644 src/modalities/models/parallelism/stages_generator_configs.py diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py new file mode 100644 index 000000000..61b8b5ba4 --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -0,0 +1,22 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + +from modalities.config.pydantic_if_types import ( + PydanticDeviceMeshIFType, + PydanticPytorchModuleType, + PydanticStagesGeneratorType, +) + + +class FQNsPerStageGeneratorConfig(BaseModel): + pass + + +class PipelinedModelConfig(BaseModel): + whole_model: PydanticPytorchModuleType + stages_generator: PydanticStagesGeneratorType + device_mesh: PydanticDeviceMeshIFType + local_rank: Annotated[int, Field(strict=True, ge=0)] + pp_schedule_name: str + num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] diff --git a/src/modalities/models/parallelism/stages_generator_configs.py b/src/modalities/models/parallelism/stages_generator_configs.py new file mode 100644 index 000000000..610be7fdd --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator_configs.py @@ -0,0 +1,13 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + + +class FQNsPerStageGeneratorConfig(BaseModel): + pass + + +class GPT2LLMStagesGeneratorConfig(BaseModel): + num_model_layers: Annotated[int, Field(strict=True, ge=1)] + input_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 + output_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 From 4a41b6c3648a56a52af86af1a30de34c2215644c Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:32:55 +0200 Subject: [PATCH 05/14] feat: wired up PP within dependency graph --- src/modalities/config/pydantic_if_types.py | 2 ++ src/modalities/registry/components.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index aa12a444d..eb7d0bce1 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -21,6 +21,7 @@ from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss +from modalities.models.parallelism.pipeline_parallelism import StagesGenerator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from 
modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -83,3 +84,4 @@ def __get_pydantic_core_schema__( PydanticDatasetBatchGeneratorIFType = Annotated[ DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF) ] +PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 28afab4bb..e6da12819 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -86,6 +86,10 @@ from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.parallelism.pipeline_parallelism import PipelineFactory +from modalities.models.parallelism.pipeline_parallelism_configs import PipelinedModelConfig +from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator +from modalities.models.parallelism.stages_generator_configs import GPT2LLMStagesGeneratorConfig from modalities.nn.model_initialization.composed_initialization import ( ComposedInitializationRoutines, ComposedModelInitializationConfig, @@ -174,6 +178,9 @@ class ComponentEntity: ComponentEntity( "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig ), + ComponentEntity("model", "pipelined", PipelineFactory.get_pipelined_model, PipelinedModelConfig), + # Pipeline Stages Generators + ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh ComponentEntity("device_mesh", "default", get_device_mesh, DeviceMeshConfig), # weight initializers @@ -209,7 +216,6 @@ class ComponentEntity: # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), - # ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None), # TODO # datasets ComponentEntity("dataset", "mem_map_dataset", DatasetFactory.get_mem_map_dataset, MemMapDatasetConfig), ComponentEntity( From ee529b746da46c0eacba9b771c9a39db70152432 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 15 Aug 2025 15:33:34 +0200 Subject: [PATCH 06/14] feat: added FQN stages generator --- .../models/parallelism/stages_generator.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 src/modalities/models/parallelism/stages_generator.py diff --git a/src/modalities/models/parallelism/stages_generator.py b/src/modalities/models/parallelism/stages_generator.py new file mode 100644 index 000000000..0a212672a --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator.py @@ -0,0 +1,120 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. 
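+#
+# A StagesGenerator assigns the model's modules to pipeline stages by their fully
+# qualified names (FQNs), e.g. "transformer.h.0", based on per-module weights.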
+ +import math +from abc import ABC, abstractmethod + + +class StagesGenerator(ABC): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + self._num_model_layers = num_model_layers + self._input_layer_equivalence = input_layer_equivalence + self._output_layer_equivalence = output_layer_equivalence + + def get_stages(self, num_layers_per_stage: int, pp_dims: int) -> list[list[str]]: + """ + Generate FQNs for each stage in a GPT-2 model. + + Args: + num_layers_per_stage (int): Number of layers per stage. + pp_dims (int): Number of pipeline parallel dimensions. + + Returns: + list[list[str]]: A list containing FQNs for each stage. + """ + + # calculate the number of stages + num_virtual_stages = math.ceil( + (self._num_model_layers + self._input_layer_equivalence + self._output_layer_equivalence) + / num_layers_per_stage + ) + if num_virtual_stages % pp_dims != 0: + raise ValueError( + f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " + f"For reference: {self._num_model_layers=} {self._input_layer_equivalence=} " + f"{self._output_layer_equivalence=} {num_layers_per_stage=}" + ) + + stages_per_rank = num_virtual_stages // pp_dims + if stages_per_rank != 1: + raise ValueError( + f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " + f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." + ) + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = self._get_potential_split_points() + # Calculate the weight per stage based on the total weight and number of stages + weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_virtual_stages) + # pack the stages with the layers + next_split_point = 0 + module_names_per_stage: list[list[str]] = [] + for _ in range(num_virtual_stages): + stage_fqns = [] + stage_weight = 0 + while next_split_point < len(potential_split_points): + fqns, weight = potential_split_points[next_split_point] + if weight > weight_per_stage: + raise ValueError( + f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " + "Please adjust the number of stages or the weight distribution." + ) + if stage_weight + weight > weight_per_stage: + break + stage_fqns.extend(fqns) + stage_weight += weight + next_split_point += 1 + module_names_per_stage.append(stage_fqns) + + return module_names_per_stage + + @abstractmethod + def _get_potential_split_points(self) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. 
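+
+        Example:
+            For a two-layer model this could be
+            [(["transformer.wte", "transformer.wpe", "transformer.drop"], 1),
+             (["transformer.h.0"], 1), (["transformer.h.1"], 1),
+             (["transformer.lm_head_norm", "transformer.lm_head"], 1)]
+            with all equivalence values at their default of 1.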
+ """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class GPT2LLMStagesGenerator(StagesGenerator): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + super().__init__(num_model_layers, input_layer_equivalence, output_layer_equivalence) + + def _get_potential_split_points( + self, + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = [ + (["transformer.wte", "transformer.wpe", "transformer.drop"], self._input_layer_equivalence), + *[([f"transformer.h.{i}"], 1) for i in range(self._num_model_layers)], + (["transformer.lm_head_norm", "transformer.lm_head"], self._output_layer_equivalence), + ] + + return potential_split_points From 625de592c02572db7626168d8504118909b768e1 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:47:00 +0200 Subject: [PATCH 07/14] feat: implemented scheduled pipeline --- .../parallelism/pipeline_parallelism.py | 84 ++++++++++++++++++- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index 0af1551af..e9ac0c755 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -3,6 +3,7 @@ # licensed under the BSD 3-Clause License. 
import copy +from enum import Enum from typing import Any, Optional, Type import torch @@ -11,22 +12,72 @@ from torch.distributed.pipelining import PipelineStage from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class +from modalities.loss_functions import Loss from modalities.models.parallelism.stages_generator import StagesGenerator from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) + + +class Pipeline: + def __init__( + self, + stage: PipelineStage, + model: nn.Module, + schedule: Optional[PipelineScheduleSingle] = None, + ): + self._stage = stage + self._model = model + self._schedule = schedule + + @property + def is_first_stage(self) -> bool: + return self._stage.is_first + + @property + def is_last_stage(self) -> bool: + return self._stage.is_last + + @property.setter + def schedule(self, schedule: PipelineScheduleSingle): + self._schedule = schedule + + +class PipelineSelectionTypes(Enum): + """Enum for pipeline selection types.""" + + STAGE = "stage" + MODEL = "model" + SCHEDULE = "schedule" + + +class ComponentSelectorFromPipeline: + @staticmethod + def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: + """Selects a component from the pipeline based on the selection type.""" + if selection_type == PipelineSelectionTypes.STAGE: + return pipeline._stage + elif selection_type == PipelineSelectionTypes.MODEL: + return pipeline._model + elif selection_type == PipelineSelectionTypes.SCHEDULE: + return pipeline._schedule + else: + raise ValueError(f"Unsupported selection type: {selection_type}") class PipelineFactory: """Pipeline factory class to create pipelined models.""" @staticmethod - def get_pipelined_model( + def get_staged_pipeline( whole_model: nn.Module, stages_generator: StagesGenerator, device_mesh: DeviceMesh, local_rank: int, pp_schedule_name: str, num_layers_per_stage: int, - ) -> torch.nn.Module: + ) -> Pipeline: device = torch.device("cuda", local_rank) pp_dims = device_mesh[ParallelismDegrees.PP.value].size() @@ -42,6 +93,10 @@ def get_pipelined_model( raise ValueError( f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." ) + # torchtitan returns tuple of stages and models as depending on the schedule + # we might have multiple stages and model parts per rank. + # So far we don't support multi-stage schedules, which is why instead of tuples + # we work directly with the stage and model. 
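+        # The call below returns the PipelineStage for this rank together with the
+        # pruned model part that backs it.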
stage, model = PipelineFactory._get_split_model( whole_model=whole_model, schedule_class=schedule_class, @@ -49,7 +104,9 @@ def get_pipelined_model( device=device, fqns_per_stage=fqns_per_stage, ) - return whole_model # TODO return pipelined model + + pipeline = Pipeline(stage=stage, model=model) + return pipeline @staticmethod def _get_split_model( @@ -171,4 +228,23 @@ def _build_stage_from_modules( device=device, group=pp_mesh.get_group("pp"), ) - return stage, whole_model + return stage, stage_modules + + @staticmethod + def get_scheduled_pipeline( + loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline + ) -> Pipeline: + # TODO: Addd validation in config that batch_size is divisible by microbatch_size + n_microbatches = batch_size // microbatch_size + num_total_stages = pp_degree + schedule_class = get_schedule_class(pp_schedule_name) + schedule = schedule_class( + stage=pipeline.stage, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + logger.info( + f"Using pipeline schedule {schedule} with {n_microbatches} microbatches and {num_total_stages} stages." + ) + pipeline.schedule = schedule + return pipeline From 9677bd6f09cb5372220c7cdac305f8f99769375a Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:47:42 +0200 Subject: [PATCH 08/14] feat: wired up scheduled and staged pipelines. --- src/modalities/config/pydantic_if_types.py | 3 ++- .../pipeline_parallelism_configs.py | 18 +++++++++++++++++- src/modalities/registry/components.py | 12 +++++++++--- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index eb7d0bce1..c91ad4549 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -21,7 +21,7 @@ from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss -from modalities.models.parallelism.pipeline_parallelism import StagesGenerator +from modalities.models.parallelism.pipeline_parallelism import Pipeline, StagesGenerator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -85,3 +85,4 @@ def __get_pydantic_core_schema__( DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF) ] PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] +PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index 61b8b5ba4..e86cc46be 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -4,19 +4,35 @@ from modalities.config.pydantic_if_types import ( PydanticDeviceMeshIFType, + PydanticPipelineType, PydanticPytorchModuleType, PydanticStagesGeneratorType, ) +from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes class FQNsPerStageGeneratorConfig(BaseModel): pass -class PipelinedModelConfig(BaseModel): +class StagedPipelineConfig(BaseModel): whole_model: 
PydanticPytorchModuleType stages_generator: PydanticStagesGeneratorType device_mesh: PydanticDeviceMeshIFType local_rank: Annotated[int, Field(strict=True, ge=0)] pp_schedule_name: str num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] + + +class ScheduledPipelineConfig(BaseModel): + loss_fn: PydanticPytorchModuleType + pp_schedule_name: str + batch_size: Annotated[int, Field(strict=True, ge=1)] + microbatch_size: Annotated[int, Field(strict=True, ge=1)] + pp_degree: Annotated[int, Field(strict=True, ge=2)] + pipeline: PydanticPipelineType + + +class ComponentSelectorFromPipelineConfig(BaseModel): + pipeline: PydanticPipelineType + selection_type: PipelineSelectionTypes diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index e6da12819..44d9820c4 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -86,8 +86,12 @@ from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory -from modalities.models.parallelism.pipeline_parallelism import PipelineFactory -from modalities.models.parallelism.pipeline_parallelism_configs import PipelinedModelConfig +from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory +from modalities.models.parallelism.pipeline_parallelism_configs import ( + ComponentSelectorFromPipelineConfig, + ScheduledPipelineConfig, + StagedPipelineConfig, +) from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator from modalities.models.parallelism.stages_generator_configs import GPT2LLMStagesGeneratorConfig from modalities.nn.model_initialization.composed_initialization import ( @@ -178,7 +182,9 @@ class ComponentEntity: ComponentEntity( "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig ), - ComponentEntity("model", "pipelined", PipelineFactory.get_pipelined_model, PipelinedModelConfig), + ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), + ComponentEntity("pipeline", "scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), + ComponentEntity("pipeline", "selector", ComponentSelectorFromPipeline.select, ComponentSelectorFromPipelineConfig), # Pipeline Stages Generators ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh From 7ac9edfd2578c3ab6c63ea29aed9057dfa22628b Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 18 Aug 2025 23:48:30 +0200 Subject: [PATCH 09/14] feat: added PP test config --- .../config_lorem_ipsum_long_fsdp2_pp.yaml | 395 ++++++++++++++++++ 1 file changed, 395 insertions(+) create mode 100644 config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml new file mode 100644 index 000000000..e5a3b61ce --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -0,0 +1,395 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: 
${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: 
${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. 
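+    # The schedule splits each local batch into batch_size / microbatch_size
+    # microbatches (see PipelineFactory.get_scheduled_pipeline).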
+ pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + + + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL + + +staged_pipeline: + component_key: model + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + 
component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +# mfu_calculator: +# component_key: mfu_calculator +# variant_key: gpt2 +# config: +# n_layer: ${model_raw.config.n_layer} +# sequence_length: ${settings.step_profile.sequence_length} +# n_embd: ${model_raw.config.n_embd} +# world_size: ${settings.cuda_env.world_size} +# raw_model: +# instance_key: model_raw +# pass_type: BY_REFERENCE +# wrapped_model: +# instance_key: initialized_model +# pass_type: BY_REFERENCE \ No newline at end of file From d9f63c11d4f0f53a823fc60dc6f016169a185100 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Tue, 19 Aug 2025 14:39:09 +0200 Subject: [PATCH 10/14] refactor: staging is now fully instantiable --- .../config_lorem_ipsum_long_fsdp2_pp.yaml | 34 ++++++++++++++----- src/modalities/config/pydantic_if_types.py | 2 ++ .../parallelism/pipeline_parallelism.py | 27 ++++++++++++--- .../pipeline_parallelism_configs.py | 12 +++++-- .../parallelism/stages_generator_configs.py | 2 +- src/modalities/registry/components.py | 2 ++ 6 files changed, 63 insertions(+), 16 deletions(-) diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml index e5a3b61ce..fa2343b93 100644 --- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -24,7 +24,7 @@ settings: enforce_last_step_checkpointed: false step_profile: gradient_accumulation_steps: 1 - local_train_micro_batch_size: 1 + local_train_micro_batch_size: 2 sequence_length: 256 training_target: num_target_tokens: @@ -190,13 +190,19 @@ app_state: instance_key: lr_scheduler pass_type: BY_REFERENCE + initialized_model: component_key: model variant_key: model_initialized config: model: - instance_key: fsdp_model - pass_type: BY_REFERENCE + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL model_initializer: component_key: model_initialization variant_key: composed @@ -223,11 +229,21 @@ scheduled_pipeline: # If fsdp_model creates a copy then this is not in the scope of # the staged pipeline. 
pipeline: - instance_key: staged_pipeline - pass_type: BY_REFERENCE - - - + component_key: pipeline + variant_key: builder + config: + stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: STAGE + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + fsdp_model: component_key: model variant_key: fsdp2_wrapped @@ -254,7 +270,7 @@ model_part: staged_pipeline: - component_key: model + component_key: pipeline variant_key: staged config: whole_model: diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index c91ad4549..2aeceb53c 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -7,6 +7,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 +from torch.distributed.pipelining import PipelineStage from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Sampler @@ -86,3 +87,4 @@ def __get_pydantic_core_schema__( ] PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] +PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)] diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index e9ac0c755..b842fd75c 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -39,7 +39,19 @@ def is_first_stage(self) -> bool: def is_last_stage(self) -> bool: return self._stage.is_last - @property.setter + @property + def stage(self) -> PipelineStage: + return self._stage + + @property + def model(self) -> nn.Module: + return self._model + + @property + def schedule(self) -> Optional[PipelineScheduleSingle]: + return self._schedule + + @schedule.setter def schedule(self, schedule: PipelineScheduleSingle): self._schedule = schedule @@ -47,9 +59,9 @@ def schedule(self, schedule: PipelineScheduleSingle): class PipelineSelectionTypes(Enum): """Enum for pipeline selection types.""" - STAGE = "stage" - MODEL = "model" - SCHEDULE = "schedule" + STAGE = "STAGE" + MODEL = "MODEL" + SCHEDULE = "SCHEDULE" class ComponentSelectorFromPipeline: @@ -69,6 +81,12 @@ def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: class PipelineFactory: """Pipeline factory class to create pipelined models.""" + @staticmethod + def get_pipeline( + stage: PipelineStage, model: nn.Module, schedule: Optional[PipelineScheduleSingle] = None + ) -> Pipeline: + return Pipeline(stage=stage, model=model, schedule=schedule) + @staticmethod def get_staged_pipeline( whole_model: nn.Module, @@ -235,6 +253,7 @@ def get_scheduled_pipeline( loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline ) -> Pipeline: # TODO: Addd validation in config that batch_size is divisible by microbatch_size + # and n_microbatches must be >= pp_degree n_microbatches = batch_size // microbatch_size num_total_stages = pp_degree schedule_class = get_schedule_class(pp_schedule_name) diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py 
b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index e86cc46be..c1aa23d48 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -4,6 +4,8 @@ from modalities.config.pydantic_if_types import ( PydanticDeviceMeshIFType, + PydanticLossIFType, + PydanticPipelineStageType, PydanticPipelineType, PydanticPytorchModuleType, PydanticStagesGeneratorType, @@ -11,7 +13,7 @@ from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes -class FQNsPerStageGeneratorConfig(BaseModel): +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate pass @@ -25,7 +27,7 @@ class StagedPipelineConfig(BaseModel): class ScheduledPipelineConfig(BaseModel): - loss_fn: PydanticPytorchModuleType + loss_fn: PydanticLossIFType pp_schedule_name: str batch_size: Annotated[int, Field(strict=True, ge=1)] microbatch_size: Annotated[int, Field(strict=True, ge=1)] @@ -36,3 +38,9 @@ class ScheduledPipelineConfig(BaseModel): class ComponentSelectorFromPipelineConfig(BaseModel): pipeline: PydanticPipelineType selection_type: PipelineSelectionTypes + + +class PipelineConfig(BaseModel): + stage: PydanticPipelineStageType + model: PydanticPytorchModuleType + schedule: PydanticPipelineType | None = None diff --git a/src/modalities/models/parallelism/stages_generator_configs.py b/src/modalities/models/parallelism/stages_generator_configs.py index 610be7fdd..5d53f091d 100644 --- a/src/modalities/models/parallelism/stages_generator_configs.py +++ b/src/modalities/models/parallelism/stages_generator_configs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -class FQNsPerStageGeneratorConfig(BaseModel): +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate pass diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 44d9820c4..167a29894 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -89,6 +89,7 @@ from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory from modalities.models.parallelism.pipeline_parallelism_configs import ( ComponentSelectorFromPipelineConfig, + PipelineConfig, ScheduledPipelineConfig, StagedPipelineConfig, ) @@ -185,6 +186,7 @@ class ComponentEntity: ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), ComponentEntity("pipeline", "scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), ComponentEntity("pipeline", "selector", ComponentSelectorFromPipeline.select, ComponentSelectorFromPipelineConfig), + ComponentEntity("pipeline", "builder", PipelineFactory.get_pipeline, PipelineConfig), # Pipeline Stages Generators ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh From 83c87b9d6d6fbbb228bab31dccf1870b12679775 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Tue, 19 Aug 2025 14:39:58 +0200 Subject: [PATCH 11/14] feat: drafted pp e2e test for fwd/bwd pass --- .../pipeline_parallelism/__init__.py | 0 ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 171 ++++++++++++++++++ .../test_pp_fwd_bwd_pass.py | 104 +++++++++++ 3 files changed, 275 insertions(+) create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/__init__.py create mode 100644 
tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml create mode 100644 tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py b/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml new file mode 100644 index 000000000..88182d266 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -0,0 +1,171 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: STAGE + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 
1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py new file mode 100644 index 000000000..fc24223e9 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -0,0 +1,104 @@ +import os +import tempfile +from pathlib import Path + +import pytest +import torch +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv + + +@pytest.fixture +def temp_file_path() -> Path: + # Create a NamedTemporaryFile that persists after closing (delete=False) + with tempfile.NamedTemporaryFile(delete=False) as tf: + file_path = tf.name + try: + yield Path(file_path) + finally: + # Clean up the file after the test + if os.path.exists(file_path): + os.remove(file_path) + + +class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType + scheduled_pipeline: PydanticPipelineType + + +@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This test requires 8 GPUs", +) +class TestPipelineParallelism: + def _get_tmp_sharding_config_path( + self, sharding_degree: int, tp_degree: int, pp_degree: int, temp_file_path: Path + ) -> Path: + working_dir = Path(os.path.dirname(__file__)) + config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml" + + with open(config_file_path, "r") as file: + config_string = file.read() + config_dict = yaml.safe_load(config_string) + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = sharding_degree + config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree + config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree + + # save to temporary file + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path + + 
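The scheduled pipeline configured above relies on the batch size, microbatch size and pipeline-parallel degree being mutually consistent; the TODO in `get_scheduled_pipeline` notes that this is not validated yet. A minimal sketch of such a check, with the two constraints taken from that TODO and the concrete values from the test config above (this helper is illustrative, not part of the patch):

```python
def validate_pp_microbatching(batch_size: int, microbatch_size: int, pp_degree: int) -> int:
    """Return the number of microbatches if the pipeline-parallel batch settings are consistent."""
    if batch_size % microbatch_size != 0:
        raise ValueError(f"{batch_size=} must be divisible by {microbatch_size=}.")
    n_microbatches = batch_size // microbatch_size
    if n_microbatches < pp_degree:
        raise ValueError(f"{n_microbatches=} must be >= {pp_degree=}.")
    return n_microbatches


# Test config above: batch size 2, microbatch size 1, pp degree 2 -> 2 microbatches.
assert validate_pp_microbatching(batch_size=2, microbatch_size=1, pp_degree=2) == 2
```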
def _get_components(self, config_file_path: Path) -> ComponentsInstantiationModel: + main_obj = Main(config_file_path) + components: ComponentsInstantiationModel = main_obj.build_components( + components_model_type=ComponentsInstantiationModel + ) + return components + + @pytest.mark.parametrize( + "sharding_degree, tp_degree, pp_degree, world_size", + [ + (2, 1, 2, 4), + # (2, 1, 4, 8), + # (2, 2, 2, 8), # TODO need to support this case + ], + ) + def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path): + tmp_sharding_config_path = self._get_tmp_sharding_config_path( + sharding_degree=sharding_degree, + tp_degree=tp_degree, + pp_degree=pp_degree, + temp_file_path=temp_file_path, + ) + mp.spawn( + self._test_pp_impl, + args=(world_size, sharding_degree, tmp_sharding_config_path), + nprocs=world_size, + join=True, + ) + + def _test_pp_impl( + self, + process_id: int, + world_size: int, + sharding_degree: int, + gpt2_model_config_path: Path, + ): + # wraps the actual test function to be able to run it in a distributed multiprocessing setup + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=22356, + ): + self._get_components(gpt2_model_config_path) + pass From 95f24701fc9940e565893668e6d07cd6dc93b3ca Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:55:12 +0200 Subject: [PATCH 12/14] refactor: renamings in the context of PP --- .../parallelism/pipeline_parallelism.py | 74 +++++++++---------- .../pipeline_parallelism_configs.py | 6 +- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py index b842fd75c..006d97a55 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism.py +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -23,57 +23,57 @@ class Pipeline: def __init__( self, - stage: PipelineStage, - model: nn.Module, - schedule: Optional[PipelineScheduleSingle] = None, + pp_stage: PipelineStage, + model_part: nn.Module, + pp_schedule: Optional[PipelineScheduleSingle] = None, ): - self._stage = stage - self._model = model - self._schedule = schedule + self._pp_stage = pp_stage + self._model_part = model_part + self._pp_schedule = pp_schedule @property - def is_first_stage(self) -> bool: - return self._stage.is_first + def is_first_pp_stage(self) -> bool: + return self._pp_stage.is_first @property - def is_last_stage(self) -> bool: - return self._stage.is_last + def is_last_pp_stage(self) -> bool: + return self._pp_stage.is_last @property - def stage(self) -> PipelineStage: - return self._stage + def pp_stage(self) -> PipelineStage: + return self._pp_stage @property - def model(self) -> nn.Module: - return self._model + def model_part(self) -> nn.Module: + return self._model_part @property - def schedule(self) -> Optional[PipelineScheduleSingle]: - return self._schedule + def pp_schedule(self) -> Optional[PipelineScheduleSingle]: + return self._pp_schedule - @schedule.setter - def schedule(self, schedule: PipelineScheduleSingle): - self._schedule = schedule + @pp_schedule.setter + def pp_schedule(self, schedule: PipelineScheduleSingle): + self._pp_schedule = schedule class PipelineSelectionTypes(Enum): """Enum for pipeline selection types.""" - STAGE = "STAGE" - MODEL = "MODEL" - SCHEDULE = "SCHEDULE" + 
PP_STAGE = "PP_STAGE" + MODEL_PART = "MODEL_PART" + PP_SCHEDULE = "PP_SCHEDULE" class ComponentSelectorFromPipeline: @staticmethod def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: """Selects a component from the pipeline based on the selection type.""" - if selection_type == PipelineSelectionTypes.STAGE: - return pipeline._stage - elif selection_type == PipelineSelectionTypes.MODEL: - return pipeline._model - elif selection_type == PipelineSelectionTypes.SCHEDULE: - return pipeline._schedule + if selection_type == PipelineSelectionTypes.PP_STAGE: + return pipeline.pp_stage + elif selection_type == PipelineSelectionTypes.MODEL_PART: + return pipeline.model_part + elif selection_type == PipelineSelectionTypes.PP_SCHEDULE: + return pipeline.pp_schedule else: raise ValueError(f"Unsupported selection type: {selection_type}") @@ -83,9 +83,9 @@ class PipelineFactory: @staticmethod def get_pipeline( - stage: PipelineStage, model: nn.Module, schedule: Optional[PipelineScheduleSingle] = None + pp_stage: PipelineStage, model_part: nn.Module, pp_schedule: Optional[PipelineScheduleSingle] = None ) -> Pipeline: - return Pipeline(stage=stage, model=model, schedule=schedule) + return Pipeline(pp_stage=pp_stage, model_part=model_part, pp_schedule=pp_schedule) @staticmethod def get_staged_pipeline( @@ -115,7 +115,7 @@ def get_staged_pipeline( # we might have multiple stages and model parts per rank. # So far we don't support multi-stage schedules, which is why instead of tuples # we work directly with the stage and model. - stage, model = PipelineFactory._get_split_model( + pp_stage, model_part = PipelineFactory._get_split_model( whole_model=whole_model, schedule_class=schedule_class, pp_mesh=pp_mesh, @@ -123,7 +123,7 @@ def get_staged_pipeline( fqns_per_stage=fqns_per_stage, ) - pipeline = Pipeline(stage=stage, model=model) + pipeline = Pipeline(pp_stage=pp_stage, model_part=model_part) return pipeline @staticmethod @@ -256,14 +256,14 @@ def get_scheduled_pipeline( # and n_microbatches must be >= pp_degree n_microbatches = batch_size // microbatch_size num_total_stages = pp_degree - schedule_class = get_schedule_class(pp_schedule_name) - schedule = schedule_class( - stage=pipeline.stage, + pp_schedule_class = get_schedule_class(pp_schedule_name) + pp_schedule = pp_schedule_class( + stage=pipeline.pp_stage, n_microbatches=n_microbatches, loss_fn=loss_fn, ) logger.info( - f"Using pipeline schedule {schedule} with {n_microbatches} microbatches and {num_total_stages} stages." + f"Using pipeline schedule {pp_schedule} with {n_microbatches} microbatches and {num_total_stages} stages." 
) - pipeline.schedule = schedule + pipeline.pp_schedule = pp_schedule return pipeline diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index c1aa23d48..831a6e15e 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -41,6 +41,6 @@ class ComponentSelectorFromPipelineConfig(BaseModel): class PipelineConfig(BaseModel): - stage: PydanticPipelineStageType - model: PydanticPytorchModuleType - schedule: PydanticPipelineType | None = None + pp_stage: PydanticPipelineStageType + model_part: PydanticPytorchModuleType + pp_schedule: PydanticPipelineType | None = None From 521e5867559c984c71ab98b12d58a349c66d69cd Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 29 Aug 2025 09:56:39 +0200 Subject: [PATCH 13/14] chore: drafted the first PP test. --- ...orem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml | 18 +++++----- .../test_pp_fwd_bwd_pass.py | 34 ++++++++++++++++--- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml index 88182d266..0ceb02a53 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -11,7 +11,7 @@ settings: world_size: ${cuda_env:WORLD_SIZE} step_profile: gradient_accumulation_steps: 1 - local_train_micro_batch_size: 2 + local_train_micro_batch_size: 4 sequence_length: 256 loss_fn: @@ -42,7 +42,7 @@ initialized_model: pipeline: instance_key: scheduled_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART model_initializer: component_key: model_initialization variant_key: composed @@ -62,21 +62,21 @@ scheduled_pipeline: pass_type: BY_REFERENCE pp_schedule_name: gpipe batch_size: ${settings.step_profile.local_train_micro_batch_size} - microbatch_size: 1 + microbatch_size: 2 pp_degree: ${device_mesh.config.pipeline_parallel_degree} pipeline: component_key: pipeline variant_key: builder config: - stage: + pp_stage: component_key: pipeline variant_key: selector config: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: STAGE - model: + selection_type: PP_STAGE + model_part: instance_key: fsdp_model pass_type: BY_REFERENCE @@ -102,7 +102,7 @@ model_part: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART staged_pipeline: component_key: pipeline @@ -123,7 +123,7 @@ staged_pipeline: pass_type: BY_REFERENCE local_rank: ${settings.cuda_env.local_rank} pp_schedule_name: gpipe - num_layers_per_stage: 2 + num_layers_per_stage: 4 model_raw: component_key: model @@ -136,7 +136,7 @@ model_raw: sequence_length: ${settings.step_profile.sequence_length} prediction_key: ${loss_fn.config.prediction_key} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 2 + n_layer: 6 n_head_q: 8 n_head_kv: 4 ffn_hidden: 128 diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py index 
fc24223e9..6f861c1ea 100644
--- a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py
+++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py
@@ -11,6 +11,7 @@
 from modalities.__main__ import Main
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType
+from modalities.models.parallelism.pipeline_parallelism import Pipeline
 from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
@@ -80,7 +81,7 @@ def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_si
         )
         mp.spawn(
             self._test_pp_impl,
-            args=(world_size, sharding_degree, tmp_sharding_config_path),
+            args=(world_size, tmp_sharding_config_path),
             nprocs=world_size,
             join=True,
         )
@@ -89,7 +90,6 @@ def _test_pp_impl(
         self,
         process_id: int,
         world_size: int,
-        sharding_degree: int,
         gpt2_model_config_path: Path,
     ):
         # wraps the actual test function to be able to run it in a distributed multiprocessing setup
@@ -100,5 +100,31 @@ def _test_pp_impl(
             world_size=world_size,
             rdvz_port=22356,
         ):
-            self._get_components(gpt2_model_config_path)
-            pass
+            components = self._get_components(gpt2_model_config_path)
+            scheduled_pipeline = components.scheduled_pipeline
+            vocab_size = 50304
+            sequence_length = 256
+            batch_size = 4
+            sequences = torch.randint(0, vocab_size, (batch_size, sequence_length))
+            targets = sequences[:, 1:].contiguous()
+            inputs = sequences[:, :-1].contiguous()
+            self._forward_step(scheduled_pipeline, inputs, targets)
+
+    def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor):
+        """Runs a forward step on the model."""
+        pp_schedule = scheduled_pipeline.pp_schedule
+        targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None)
+        if scheduled_pipeline.is_first_pp_stage:  # first stage
+            pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs)
+        else:  # non-first stage
+            pp_schedule.step(target=targets, losses=losses, input_batch=inputs)
+
+        # accumulate losses across pipeline microbatches
+        # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
+        device = torch.device("cuda", torch.cuda.current_device())
+        loss = (
+            torch.mean(torch.stack(losses)).to(device)
+            if scheduled_pipeline.is_last_pp_stage
+            else torch.tensor([-1.0], device=device)
+        )
+        return loss
From 002b0ae557411351dc274be97f7e0e6c59c0afd8 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Sat, 30 Aug 2025 00:46:18 +0200
Subject: [PATCH 14/14] chore: pp config fixes

---
 .../training/config_lorem_ipsum_long_fsdp2_pp.yaml | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml
index fa2343b93..381550a20 100644
--- a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml
+++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml
@@ -202,7 +202,7 @@ initialized_model:
         pipeline:
           instance_key: scheduled_pipeline
           pass_type: BY_REFERENCE
-        selection_type: MODEL
+        selection_type: MODEL_PART
   model_initializer:
     component_key: model_initialization
     variant_key: composed
@@ -232,15 +232,15 @@ scheduled_pipeline:
       component_key: pipeline
      variant_key: builder
       config:
-        stage:
+        pp_stage:
           component_key: pipeline
           variant_key: selector
           config:
             pipeline:
               instance_key: staged_pipeline
               pass_type: BY_REFERENCE
-            selection_type: STAGE
-          model:
+            selection_type: PP_STAGE
+          model_part:
            instance_key: 
fsdp_model pass_type: BY_REFERENCE @@ -266,7 +266,7 @@ model_part: pipeline: instance_key: staged_pipeline pass_type: BY_REFERENCE - selection_type: MODEL + selection_type: MODEL_PART staged_pipeline:
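Taken together, the component wiring described by these configs corresponds roughly to the imperative construction sketched below. This is for orientation only: the factory and property names are taken from the patches in this series, the keyword names mirror the YAML keys of the corresponding components and may differ slightly from the actual factory signatures, and `model_raw`, `stages_generator`, `device_mesh`, `loss_fn`, `local_rank` and the FSDP2 wrapping helper are assumed to be provided elsewhere.

```python
# Rough sketch of the YAML wiring: staged -> FSDP2-wrapped model part -> builder -> scheduled.
from modalities.models.parallelism.pipeline_parallelism import PipelineFactory


def build_scheduled_pipeline(model_raw, stages_generator, device_mesh, loss_fn, local_rank, wrap_with_fsdp2):
    # 1. `staged_pipeline`: split the whole model and keep this rank's stage and model part.
    staged = PipelineFactory.get_staged_pipeline(
        whole_model=model_raw,
        stages_generator=stages_generator,
        device_mesh=device_mesh,
        local_rank=local_rank,
        pp_schedule_name="gpipe",
        num_layers_per_stage=4,
    )

    # 2. `fsdp_model`: shard the local model part with FSDP2 (helper assumed, passed in by the caller).
    fsdp_model = wrap_with_fsdp2(staged.model_part, device_mesh)

    # 3. `builder`: recombine the pipeline stage with the wrapped model part.
    pipeline = PipelineFactory.get_pipeline(pp_stage=staged.pp_stage, model_part=fsdp_model)

    # 4. `scheduled`: attach the runtime schedule (values from the test config above).
    return PipelineFactory.get_scheduled_pipeline(
        loss_fn=loss_fn,
        pp_schedule_name="gpipe",
        batch_size=4,
        microbatch_size=2,
        pp_degree=2,
        pipeline=pipeline,
    )
```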