diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml new file mode 100644 index 000000000..381550a20 --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -0,0 +1,411 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: 
packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. 
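+    # Illustrative note (not a config component, values taken from this config): the scheduled
+    # pipeline derives n_microbatches = batch_size // microbatch_size, i.e. 2 // 1 = 2 microbatches
+    # per step here, which matches pp_degree = 2. For a GPipe-style single-stage schedule,
+    # n_microbatches should be at least pp_degree so that every stage receives work each step.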
+ pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +# mfu_calculator: +# component_key: mfu_calculator +# variant_key: gpt2 +# config: +# n_layer: ${model_raw.config.n_layer} +# sequence_length: ${settings.step_profile.sequence_length} +# n_embd: ${model_raw.config.n_embd} +# world_size: ${settings.cuda_env.world_size} +# raw_model: +# instance_key: model_raw +# pass_type: BY_REFERENCE +# wrapped_model: +# instance_key: initialized_model +# pass_type: BY_REFERENCE \ No newline at end of file diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index aa12a444d..2aeceb53c 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -7,6 +7,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 +from torch.distributed.pipelining import PipelineStage from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Sampler @@ -21,6 +22,7 @@ from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss +from modalities.models.parallelism.pipeline_parallelism import Pipeline, 
StagesGenerator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -83,3 +85,6 @@ def __get_pydantic_core_schema__( PydanticDatasetBatchGeneratorIFType = Annotated[ DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF) ] +PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] +PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] +PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)] diff --git a/src/modalities/models/parallelism/__init__.py b/src/modalities/models/parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py new file mode 100644 index 000000000..006d97a55 --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -0,0 +1,269 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import copy +from enum import Enum +from typing import Any, Optional, Type + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class + +from modalities.loss_functions import Loss +from modalities.models.parallelism.stages_generator import StagesGenerator +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) + + +class Pipeline: + def __init__( + self, + pp_stage: PipelineStage, + model_part: nn.Module, + pp_schedule: Optional[PipelineScheduleSingle] = None, + ): + self._pp_stage = pp_stage + self._model_part = model_part + self._pp_schedule = pp_schedule + + @property + def is_first_pp_stage(self) -> bool: + return self._pp_stage.is_first + + @property + def is_last_pp_stage(self) -> bool: + return self._pp_stage.is_last + + @property + def pp_stage(self) -> PipelineStage: + return self._pp_stage + + @property + def model_part(self) -> nn.Module: + return self._model_part + + @property + def pp_schedule(self) -> Optional[PipelineScheduleSingle]: + return self._pp_schedule + + @pp_schedule.setter + def pp_schedule(self, schedule: PipelineScheduleSingle): + self._pp_schedule = schedule + + +class PipelineSelectionTypes(Enum): + """Enum for pipeline selection types.""" + + PP_STAGE = "PP_STAGE" + MODEL_PART = "MODEL_PART" + PP_SCHEDULE = "PP_SCHEDULE" + + +class ComponentSelectorFromPipeline: + @staticmethod + def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: + """Selects a component from the pipeline based on the selection type.""" + if selection_type == PipelineSelectionTypes.PP_STAGE: + return pipeline.pp_stage + elif selection_type == PipelineSelectionTypes.MODEL_PART: + return pipeline.model_part + elif selection_type == PipelineSelectionTypes.PP_SCHEDULE: + return pipeline.pp_schedule + else: + raise ValueError(f"Unsupported selection type: {selection_type}") + + +class PipelineFactory: + """Pipeline factory class to create pipelined models.""" + + @staticmethod + def get_pipeline( + 
pp_stage: PipelineStage, model_part: nn.Module, pp_schedule: Optional[PipelineScheduleSingle] = None
+    ) -> Pipeline:
+        return Pipeline(pp_stage=pp_stage, model_part=model_part, pp_schedule=pp_schedule)
+
+    @staticmethod
+    def get_staged_pipeline(
+        whole_model: nn.Module,
+        stages_generator: StagesGenerator,
+        device_mesh: DeviceMesh,
+        local_rank: int,
+        pp_schedule_name: str,
+        num_layers_per_stage: int,
+    ) -> Pipeline:
+        device = torch.device("cuda", local_rank)
+        pp_dims = device_mesh[ParallelismDegrees.PP.value].size()
+
+        fqns_per_stage = stages_generator.get_stages(
+            num_layers_per_stage=num_layers_per_stage,
+            pp_dims=pp_dims,
+        )
+
+        pp_mesh = device_mesh[ParallelismDegrees.PP.value]
+        schedule_class = get_schedule_class(pp_schedule_name)
+        is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+        if not is_single_stage_schedule:
+            raise ValueError(
+                f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules."
+            )
+        # torchtitan returns tuples of stages and model parts because, depending on the schedule,
+        # a rank may own multiple stages and model parts.
+        # Since we only support single-stage schedules so far, we work directly with a single
+        # stage and model part instead of tuples.
+        pp_stage, model_part = PipelineFactory._get_split_model(
+            whole_model=whole_model,
+            schedule_class=schedule_class,
+            pp_mesh=pp_mesh,
+            device=device,
+            fqns_per_stage=fqns_per_stage,
+        )
+
+        pipeline = Pipeline(pp_stage=pp_stage, model_part=model_part)
+        return pipeline
+
+    @staticmethod
+    def _get_split_model(
+        whole_model: nn.Module,
+        schedule_class: Type[PipelineScheduleSingle],
+        pp_mesh: DeviceMesh,
+        device: torch.device,
+        fqns_per_stage: list[list[str]],
+    ) -> tuple[PipelineStage, nn.Module]:
+        def get_stage_id_of_pp_rank(pp_mesh: DeviceMesh) -> int:
+            # NOTE: torchtitan uses a more involved mapping from pp rank to stage id,
+            # since it also allows for multi-stage schedules.
+            pp_rank = pp_mesh.get_local_rank()
+            return pp_rank
+
+        def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]:
+            # Builds a nested dict from the FQNs; leaf modules are represented by empty dicts.
+            fqn_tree = {}
+            fqns = set(fqns)  # Ensure unique FQNs
+            for fqn in fqns:
+                parts = fqn.split(".")
+                current_level = fqn_tree
+                for part in parts[:-1]:
+                    if part not in current_level:
+                        current_level[part] = {}
+                    elif len(current_level[part]) == 0:
+                        raise ValueError(f"Part {part} of {fqn} already exists in the tree as a leaf node.")
+                    current_level = current_level[part]
+                if parts[-1] in current_level:
+                    raise ValueError(
+                        f"Leaf of {fqn} has already been defined in the tree as an intermediate node or leaf! "
+                        "Cannot replace the existing node as a leaf."
+                    )
+                current_level[parts[-1]] = {}
+
+            return fqn_tree
+
+        def _build_stage_from_modules(
+            fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None
+        ) -> Optional[nn.Module]:
+            if isinstance(module, nn.ModuleDict):
+                if module_name not in fqn_tree:
+                    dict_modules = nn.ModuleDict({})
+                else:
+                    if len(fqn_tree[module_name]) == 0:
+                        # If the module is a leaf node, we can directly use it
+                        dict_modules = module
+                    else:
+                        # If the module is not a leaf node, we need to build a staged module
+                        # recursively from the FQN tree
+                        dict_modules = {}
+                        dict_module_names = [name for name in module.keys() if name in fqn_tree[module_name]]
+                        for dict_module_name in dict_module_names:
+                            dict_modules[dict_module_name] = _build_stage_from_modules(
+                                fqn_tree=fqn_tree[module_name],
+                                module=module[dict_module_name],
+                                module_name=dict_module_name,
+                            )
+                        dict_modules = nn.ModuleDict(dict_modules)
+                # setattr(module, module_name, dict_modules)
+                return dict_modules
+
+            elif isinstance(module, nn.ModuleList):
+                if module_name not in fqn_tree:
+                    list_modules = nn.ModuleList([])
+                else:
+                    if len(fqn_tree[module_name]) == 0:
+                        # If the module is a leaf node, we can directly use it
+                        list_modules = module
+                    else:
+                        # If the module is not a leaf node, we need to build a staged module
+                        # recursively from the FQN tree
+                        list_modules = []
+                        list_indices = [i for i in range(len(module)) if str(i) in fqn_tree[module_name]]
+                        for idx in list_indices:
+                            list_modules.append(
+                                _build_stage_from_modules(
+                                    fqn_tree=fqn_tree[module_name], module=module[idx], module_name=str(idx)
+                                )
+                            )
+                        list_modules = nn.ModuleList(list_modules)
+                # setattr(module, module_name, list_modules)
+                return list_modules
+
+            else:  # normal nn.Module
+                if module_name is not None and module_name not in fqn_tree:
+                    # If the module is not in the FQN tree, set it to None
+                    return None
+                elif module_name is not None and len(fqn_tree[module_name]) == 0:
+                    # If the module is a leaf node, we can directly use it
+                    return module
+                else:
+                    # If the module is an intermediate node (or the root), we build its children
+                    # recursively from the FQN tree, replacing pruned children with None
+                    for child_name, child_module in module.named_children():
+                        staged_module = _build_stage_from_modules(
+                            fqn_tree=fqn_tree, module=child_module, module_name=child_name
+                        )
+                        setattr(module, child_name, staged_module)
+
+                    return module
+
+        if not issubclass(schedule_class, PipelineScheduleSingle):
+            raise NotImplementedError("Only single-stage schedules are supported for pipeline parallelism.")
+
+        # NOTE: For multi-stage schedules, e.g., Interleaved 1F1B, we would have multiple stages per pp rank.
+        # This would need to be adapted accordingly in that case.
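+        # Illustrative sketch (assuming the 2-layer GPT-2 model from config_lorem_ipsum_long_fsdp2_pp.yaml
+        # with num_layers_per_stage=2 and pp_degree=2): fqns_per_stage would look like
+        #   [["transformer.wte", "transformer.wpe", "transformer.drop", "transformer.h.0"],
+        #    ["transformer.h.1", "transformer.lm_head_norm", "transformer.lm_head"]]
+        # and _get_fqn_tree(...) for stage 0 would yield
+        #   {"transformer": {"wte": {}, "wpe": {}, "drop": {}, "h": {"0": {}}}},
+        # from which _build_stage_from_modules prunes every module not referenced in the tree.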
+ stage_idx = get_stage_id_of_pp_rank(pp_mesh) + module_names = fqns_per_stage[stage_idx] + whole_model = copy.deepcopy(whole_model) + fqn_tree = _get_fqn_tree(module_names) + stage_modules = _build_stage_from_modules(fqn_tree, whole_model) + stage = PipelineStage( + submodule=stage_modules, + stage_index=stage_idx, + num_stages=len(fqns_per_stage), + device=device, + group=pp_mesh.get_group("pp"), + ) + return stage, stage_modules + + @staticmethod + def get_scheduled_pipeline( + loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline + ) -> Pipeline: + # TODO: Addd validation in config that batch_size is divisible by microbatch_size + # and n_microbatches must be >= pp_degree + n_microbatches = batch_size // microbatch_size + num_total_stages = pp_degree + pp_schedule_class = get_schedule_class(pp_schedule_name) + pp_schedule = pp_schedule_class( + stage=pipeline.pp_stage, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + logger.info( + f"Using pipeline schedule {pp_schedule} with {n_microbatches} microbatches and {num_total_stages} stages." + ) + pipeline.pp_schedule = pp_schedule + return pipeline diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py new file mode 100644 index 000000000..831a6e15e --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -0,0 +1,46 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + +from modalities.config.pydantic_if_types import ( + PydanticDeviceMeshIFType, + PydanticLossIFType, + PydanticPipelineStageType, + PydanticPipelineType, + PydanticPytorchModuleType, + PydanticStagesGeneratorType, +) +from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes + + +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate + pass + + +class StagedPipelineConfig(BaseModel): + whole_model: PydanticPytorchModuleType + stages_generator: PydanticStagesGeneratorType + device_mesh: PydanticDeviceMeshIFType + local_rank: Annotated[int, Field(strict=True, ge=0)] + pp_schedule_name: str + num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] + + +class ScheduledPipelineConfig(BaseModel): + loss_fn: PydanticLossIFType + pp_schedule_name: str + batch_size: Annotated[int, Field(strict=True, ge=1)] + microbatch_size: Annotated[int, Field(strict=True, ge=1)] + pp_degree: Annotated[int, Field(strict=True, ge=2)] + pipeline: PydanticPipelineType + + +class ComponentSelectorFromPipelineConfig(BaseModel): + pipeline: PydanticPipelineType + selection_type: PipelineSelectionTypes + + +class PipelineConfig(BaseModel): + pp_stage: PydanticPipelineStageType + model_part: PydanticPytorchModuleType + pp_schedule: PydanticPipelineType | None = None diff --git a/src/modalities/models/parallelism/stages_generator.py b/src/modalities/models/parallelism/stages_generator.py new file mode 100644 index 000000000..0a212672a --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator.py @@ -0,0 +1,120 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. 
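+
+# Example for orientation (values taken from the test config in this PR): a model with 6 transformer
+# layers, input_layer_equivalence=1 and output_layer_equivalence=1, packed with num_layers_per_stage=4,
+# yields ceil((6 + 1 + 1) / 4) = 2 virtual stages. With pp_degree=2 this is exactly one stage per
+# pipeline rank, as required by the single-stage schedules supported here.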
+ +import math +from abc import ABC, abstractmethod + + +class StagesGenerator(ABC): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + self._num_model_layers = num_model_layers + self._input_layer_equivalence = input_layer_equivalence + self._output_layer_equivalence = output_layer_equivalence + + def get_stages(self, num_layers_per_stage: int, pp_dims: int) -> list[list[str]]: + """ + Generate FQNs for each stage in a GPT-2 model. + + Args: + num_layers_per_stage (int): Number of layers per stage. + pp_dims (int): Number of pipeline parallel dimensions. + + Returns: + list[list[str]]: A list containing FQNs for each stage. + """ + + # calculate the number of stages + num_virtual_stages = math.ceil( + (self._num_model_layers + self._input_layer_equivalence + self._output_layer_equivalence) + / num_layers_per_stage + ) + if num_virtual_stages % pp_dims != 0: + raise ValueError( + f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " + f"For reference: {self._num_model_layers=} {self._input_layer_equivalence=} " + f"{self._output_layer_equivalence=} {num_layers_per_stage=}" + ) + + stages_per_rank = num_virtual_stages // pp_dims + if stages_per_rank != 1: + raise ValueError( + f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " + f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." + ) + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = self._get_potential_split_points() + # Calculate the weight per stage based on the total weight and number of stages + weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_virtual_stages) + # pack the stages with the layers + next_split_point = 0 + module_names_per_stage: list[list[str]] = [] + for _ in range(num_virtual_stages): + stage_fqns = [] + stage_weight = 0 + while next_split_point < len(potential_split_points): + fqns, weight = potential_split_points[next_split_point] + if weight > weight_per_stage: + raise ValueError( + f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " + "Please adjust the number of stages or the weight distribution." + ) + if stage_weight + weight > weight_per_stage: + break + stage_fqns.extend(fqns) + stage_weight += weight + next_split_point += 1 + module_names_per_stage.append(stage_fqns) + + return module_names_per_stage + + @abstractmethod + def _get_potential_split_points(self) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. 
+ """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class GPT2LLMStagesGenerator(StagesGenerator): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + super().__init__(num_model_layers, input_layer_equivalence, output_layer_equivalence) + + def _get_potential_split_points( + self, + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = [ + (["transformer.wte", "transformer.wpe", "transformer.drop"], self._input_layer_equivalence), + *[([f"transformer.h.{i}"], 1) for i in range(self._num_model_layers)], + (["transformer.lm_head_norm", "transformer.lm_head"], self._output_layer_equivalence), + ] + + return potential_split_points diff --git a/src/modalities/models/parallelism/stages_generator_configs.py b/src/modalities/models/parallelism/stages_generator_configs.py new file mode 100644 index 000000000..5d53f091d --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator_configs.py @@ -0,0 +1,13 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + + +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate + pass + + +class GPT2LLMStagesGeneratorConfig(BaseModel): + num_model_layers: Annotated[int, Field(strict=True, ge=1)] + input_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 + output_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 28afab4bb..167a29894 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -86,6 +86,15 @@ from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory +from modalities.models.parallelism.pipeline_parallelism_configs import ( + ComponentSelectorFromPipelineConfig, + PipelineConfig, + ScheduledPipelineConfig, + StagedPipelineConfig, +) +from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator +from modalities.models.parallelism.stages_generator_configs import GPT2LLMStagesGeneratorConfig from modalities.nn.model_initialization.composed_initialization import ( ComposedInitializationRoutines, ComposedModelInitializationConfig, @@ -174,6 +183,12 @@ class ComponentEntity: ComponentEntity( "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig ), + ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), + ComponentEntity("pipeline", 
"scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), + ComponentEntity("pipeline", "selector", ComponentSelectorFromPipeline.select, ComponentSelectorFromPipelineConfig), + ComponentEntity("pipeline", "builder", PipelineFactory.get_pipeline, PipelineConfig), + # Pipeline Stages Generators + ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh ComponentEntity("device_mesh", "default", get_device_mesh, DeviceMeshConfig), # weight initializers @@ -209,7 +224,6 @@ class ComponentEntity: # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), - # ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None), # TODO # datasets ComponentEntity("dataset", "mem_map_dataset", DatasetFactory.get_mem_map_dataset, MemMapDatasetConfig), ComponentEntity( diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py b/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml new file mode 100644 index 000000000..0ceb02a53 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -0,0 +1,171 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + 
model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 4 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py new file mode 100644 index 000000000..6f861c1ea --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -0,0 +1,130 @@ +import os +import tempfile +from pathlib import Path + +import pytest +import torch +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticPipelineType +from modalities.models.parallelism.pipeline_parallelism import Pipeline +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv + + +@pytest.fixture +def temp_file_path() -> Path: + # Create a NamedTemporaryFile that persists after closing (delete=False) + with tempfile.NamedTemporaryFile(delete=False) as tf: + file_path = tf.name + try: + yield Path(file_path) + finally: + # Clean up the file after the test + if os.path.exists(file_path): + os.remove(file_path) + + +class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticFSDP2ModuleType + 
scheduled_pipeline: PydanticPipelineType
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 8,
+    reason="This test requires 8 GPUs",
+)
+class TestPipelineParallelism:
+    def _get_tmp_sharding_config_path(
+        self, sharding_degree: int, tp_degree: int, pp_degree: int, temp_file_path: Path
+    ) -> Path:
+        working_dir = Path(os.path.dirname(__file__))
+        config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml"
+
+        with open(config_file_path, "r") as file:
+            config_string = file.read()
+        config_dict = yaml.safe_load(config_string)
+        config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = sharding_degree
+        config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree
+        config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree
+
+        # save to temporary file
+        with open(temp_file_path, "w") as file:
+            yaml.dump(config_dict, file)
+
+        return temp_file_path
+
+    def _get_components(self, config_file_path: Path) -> ComponentsInstantiationModel:
+        main_obj = Main(config_file_path)
+        components: ComponentsInstantiationModel = main_obj.build_components(
+            components_model_type=ComponentsInstantiationModel
+        )
+        return components
+
+    @pytest.mark.parametrize(
+        "sharding_degree, tp_degree, pp_degree, world_size",
+        [
+            (2, 1, 2, 4),
+            # (2, 1, 4, 8),
+            # (2, 2, 2, 8),  # TODO need to support this case
+        ],
+    )
+    def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path):
+        tmp_sharding_config_path = self._get_tmp_sharding_config_path(
+            sharding_degree=sharding_degree,
+            tp_degree=tp_degree,
+            pp_degree=pp_degree,
+            temp_file_path=temp_file_path,
+        )
+        mp.spawn(
+            self._test_pp_impl,
+            args=(world_size, tmp_sharding_config_path),
+            nprocs=world_size,
+            join=True,
+        )
+
+    def _test_pp_impl(
+        self,
+        process_id: int,
+        world_size: int,
+        gpt2_model_config_path: Path,
+    ):
+        # wraps the actual test function to be able to run it in a distributed multiprocessing setup
+        with MultiProcessingCudaEnv(
+            process_group_backend=ProcessGroupBackendType.nccl,
+            global_rank=process_id,
+            local_rank=process_id,
+            world_size=world_size,
+            rdvz_port=22356,
+        ):
+            components = self._get_components(gpt2_model_config_path)
+            scheduled_pipeline = components.scheduled_pipeline
+            vocab_size = 50304
+            sequence_length = 256
+            batch_size = 4
+            sequences = torch.randint(0, vocab_size, (batch_size, sequence_length))
+            targets = sequences[:, 1:].contiguous()
+            inputs = sequences[:, :-1].contiguous()
+            self._forward_step(scheduled_pipeline, inputs, targets)
+
+    def _forward_step(
+        self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor
+    ) -> torch.Tensor:
+        """Runs a forward (and backward) step on the pipelined model and returns the accumulated loss."""
+        pp_schedule = scheduled_pipeline.pp_schedule
+        targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None)
+        if scheduled_pipeline.is_first_pp_stage:  # first stage
+            pp_schedule.step(inputs, target=targets, losses=losses, input_batch=inputs)
+        else:  # non-first stage
+            pp_schedule.step(target=targets, losses=losses, input_batch=inputs)
+
+        # accumulate losses across pipeline microbatches
+        # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
+        device = torch.device("cuda", torch.cuda.current_device())
+        loss = (
+            torch.mean(torch.stack(losses)).to(device)
+            if scheduled_pipeline.is_last_pp_stage
+            else torch.tensor([-1.0], device=device)
+        )
+        return loss