Add FSDP for NeMo 2.0 #9748

Merged on Aug 29, 2024 (49 commits)

Commits
375b282
modify code structure and add strategy
blahBlahhhJ Jul 15, 2024
7de4000
correct doc url
blahBlahhhJ Jul 15, 2024
73f547f
Apply isort and black reformatting
blahBlahhhJ Jul 15, 2024
708221a
extract common elements and add callback
blahBlahhhJ Jul 15, 2024
a8340ca
Apply isort and black reformatting
blahBlahhhJ Jul 15, 2024
91563b0
add iomixin
blahBlahhhJ Jul 18, 2024
c8e8182
Apply isort and black reformatting
blahBlahhhJ Jul 18, 2024
9881054
update strategies
blahBlahhhJ Jul 31, 2024
a30dc6c
update callback
blahBlahhhJ Jul 31, 2024
563fb7d
add training step to strategy
blahBlahhhJ Jul 31, 2024
6ceb8f3
Apply isort and black reformatting
blahBlahhhJ Jul 31, 2024
e0ec40d
remove unused import
blahBlahhhJ Jul 31, 2024
c77e3b7
add iomixin to strategy & precision
blahBlahhhJ Jul 31, 2024
3049fac
Apply isort and black reformatting
artbataev Jul 31, 2024
610f602
add val/test steps to strategy
blahBlahhhJ Jul 31, 2024
c3fc092
add documentations
blahBlahhhJ Jul 31, 2024
1624b2c
Apply isort and black reformatting
blahBlahhhJ Jul 31, 2024
26bbf90
add default sharding for fsdp. add setup callback detection
blahBlahhhJ Aug 1, 2024
31b2626
Apply isort and black reformatting
blahBlahhhJ Aug 1, 2024
a3006ef
extract checkpoint io logic
blahBlahhhJ Aug 2, 2024
a17c99f
Apply isort and black reformatting
blahBlahhhJ Aug 2, 2024
4eb67a3
clean up unused imports
blahBlahhhJ Aug 5, 2024
2d08844
sync new megatron strategy changes
blahBlahhhJ Aug 8, 2024
3d43a75
break down setup callback back into strategy
blahBlahhhJ Aug 12, 2024
efdcd56
Apply isort and black reformatting
blahBlahhhJ Aug 12, 2024
24f6e7b
reorder stuff
blahBlahhhJ Aug 12, 2024
94a6bf1
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 12, 2024
6f14d1e
fix data logic
blahBlahhhJ Aug 12, 2024
5a524a9
minor fix
blahBlahhhJ Aug 12, 2024
0c3f657
add dtensor ckpt conversion support
blahBlahhhJ Aug 13, 2024
220a4d5
Apply isort and black reformatting
blahBlahhhJ Aug 13, 2024
26b9849
support hsdp
blahBlahhhJ Aug 13, 2024
54c12a9
Apply isort and black reformatting
blahBlahhhJ Aug 13, 2024
881a2bb
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 15, 2024
5afe7b6
remove iomixin
blahBlahhhJ Aug 15, 2024
5263527
Apply isort and black reformatting
blahBlahhhJ Aug 15, 2024
29106c3
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 16, 2024
871785a
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 19, 2024
c01ebb1
Apply isort and black reformatting
blahBlahhhJ Aug 19, 2024
d3137d6
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 21, 2024
20ccd35
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 26, 2024
438502a
fix import
blahBlahhhJ Aug 27, 2024
94f5faa
refactor loss reduction
blahBlahhhJ Aug 28, 2024
be75983
Apply isort and black reformatting
blahBlahhhJ Aug 28, 2024
d4f74f5
Merge branch 'main' into jasonwan/fsdp
blahBlahhhJ Aug 28, 2024
97c5e4d
clean up
blahBlahhhJ Aug 28, 2024
0ac8f9a
add unittest
blahBlahhhJ Aug 28, 2024
ed5a0d9
Apply isort and black reformatting
blahBlahhhJ Aug 28, 2024
06afb25
clean up
blahBlahhhJ Aug 28, 2024
Files changed
2 changes: 1 addition & 1 deletion nemo/lightning/__init__.py
@@ -18,7 +18,7 @@
 from nemo.lightning.pytorch.optim import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule, lr_scheduler
 from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
 from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
-from nemo.lightning.pytorch.strategies import MegatronStrategy
+from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy
 from nemo.lightning.pytorch.trainer import Trainer
 from nemo.lightning.resume import AutoResume
8 changes: 8 additions & 0 deletions nemo/lightning/pytorch/strategies/__init__.py
@@ -0,0 +1,8 @@
from nemo.lightning.pytorch.strategies.fsdp_strategy import FSDPStrategy
from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy


__all__ = [
    "FSDPStrategy",
    "MegatronStrategy",
]
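
With these exports in place (together with the nemo/lightning/__init__.py change above), the new strategy is importable from either the package root or the strategies subpackage. A small illustration, not part of the diff:

from nemo.lightning import FSDPStrategy
# equivalently:
from nemo.lightning.pytorch.strategies import FSDPStrategy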
245 changes: 245 additions & 0 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
@@ -0,0 +1,245 @@
import shutil
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
import torch
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.strategies.fsdp import _get_sharded_state_dict_context
from megatron.core.transformer.transformer_layer import TransformerLayer
from pytorch_lightning.strategies.fsdp import FSDPStrategy as PLFSDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.distributed.checkpoint.state_dict import (  # get_state_dict,
    StateDictOptions,
    get_optimizer_state_dict,
    set_state_dict,
)
from torch.utils.data import DataLoader
from typing_extensions import override

from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import (
    ckpt_to_dir,
    fix_progress_bar,
    get_checkpoint_io,
    init_model_parallel,
    mcore_to_pyt_sharded_state_dict,
    pyt_to_mcore_state_dict,
    setup_data_sampler,
    setup_parallel_ranks,
)


class FSDPStrategy(PLFSDPStrategy, io.IOMixin):
    """FSDP plugin for PyTorch Lightning.

    This strategy implements Fully Sharded Data Parallel (FSDP) training using PyTorch's
    native FSDP methods. Compared with MegatronStrategy, FSDPStrategy is designed to be
    more lightweight, with minimal modifications over Lightning's FSDPStrategy, while
    preserving the features needed for compatibility with NeMo and MCore.
    By default, this strategy wraps each TransformerLayer in its own FSDP unit.

    Note:
        This strategy is designed to work with NVIDIA's Megatron-LM framework and requires
        model implementations that are compatible with Megatron's parallelism techniques.
    Note:
        Due to the different optimizer structure (FSDP only uses torch-native optimizers),
        MegatronStrategy cannot resume training from checkpoints saved by FSDPStrategy, and
        vice versa. However, the model weight structure is kept compatible, so switching
        strategies is possible if users only need the weights, not the optimizer states
        (e.g., run pretraining with Megatron 4D parallelism and run SFT with FSDP).
    """

    def __init__(
        self,
        auto_wrap_policy={TransformerLayer},
        state_dict_type="sharded",
        ckpt_include_optimizer=False,
        data_sampler=None,
        **kwargs,
    ):
        super().__init__(auto_wrap_policy=auto_wrap_policy, state_dict_type=state_dict_type, **kwargs)

        self.data_sampler = data_sampler
        self.ckpt_include_optimizer = ckpt_include_optimizer

    @override
    def setup_environment(self) -> None:
        setup_parallel_ranks(self)
        super().setup_environment()
        init_model_parallel(self.model)

    @override
    def setup(self, trainer: pl.Trainer) -> None:
        self.trainer = trainer
        setup_data_sampler(self.trainer)
        fix_progress_bar(trainer)
        super().setup(trainer)

    def _get_loss_reduction(self, step_type: str):
        # A step-specific reduction (e.g. `training_loss_reduction`) takes precedence
        # over a generic `loss_reduction` attribute on the LightningModule.
        for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
            if hasattr(self.lightning_module, fn_name):
                return getattr(self.lightning_module, fn_name)
        return None

    def _step_proxy(self, step_type, batch, batch_idx=None):
        method_name = f"{step_type}_step"
        if self.model != self.lightning_module:
            # The model is FSDP-wrapped: redirect the call through the wrapper so that
            # FSDP's pre/post-forward hooks run around the LightningModule's step method.
            loss = self._forward_redirection(self.model, self.lightning_module, method_name, batch, batch_idx)
        else:
            loss = getattr(self.lightning_module, method_name)(batch, batch_idx)

        _loss_reduction = self._get_loss_reduction(step_type)
        if _loss_reduction:
            return _loss_reduction.forward(batch, loss)
        return loss, {'avg': loss}
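
    # Sketch of the loss-reduction contract implied by `_step_proxy` (an editorial
    # reading of the calls above, not an API definition): the reduction object's
    # `forward(batch, loss)` must return a tuple `(loss_for_backward, reduced)`, where
    # `reduced` is a dict containing at least an 'avg' entry, which the step methods
    # below log as the reduced loss.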

    @override
    def training_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.train_step_context():
            loss, reduced = self._step_proxy("training", batch, batch_idx)

            self.lightning_module.log(
                'global_step',
                self.trainer.global_step,
                prog_bar=True,
                rank_zero_only=True,
                batch_size=1,
            )

            self.lightning_module.log(
                'step',
                self.trainer.global_step,
            )
            self.lightning_module.log(
                'reduced_train_loss', reduced['avg'], prog_bar=True, rank_zero_only=True, batch_size=1
            )

            # returns unreduced loss for backward
            return loss

    @override
    def validation_step(self, batch, batch_idx=None) -> Any:
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.val_step_context():
            loss, reduced = self._step_proxy("validation", batch, batch_idx)
            self.lightning_module.log('val_loss', reduced['avg'], rank_zero_only=True, batch_size=1)
            return loss

    @override
    def test_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.test_step_context():
            loss, reduced = self._step_proxy("test", batch, batch_idx)
            self.lightning_module.log('test_loss', reduced['avg'], rank_zero_only=True, batch_size=1)

            return loss

    @override
    def predict_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
        assert self.lightning_module is not None
        assert self.model is not None
        with self.precision_plugin.predict_step_context():
            loss, reduced = self._step_proxy("predict", batch, batch_idx)
            return reduced

    @override
    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
        if self.data_sampler:
            return self.data_sampler.transform_dataloader(dataloader)

        return dataloader
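
    # For example, a `MegatronDataSampler` (exported from nemo.lightning, as shown in
    # the __init__.py diff above) can be passed as `data_sampler` so that dataloaders
    # are transformed for Megatron-style batch handling; the exact sampler arguments
    # are outside this diff.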

    @property
    @override
    def checkpoint_io(self) -> CheckpointIO:
        return get_checkpoint_io(self._checkpoint_io)

    @checkpoint_io.setter
    def checkpoint_io(self, io: CheckpointIO) -> None:
        self._checkpoint_io = io

    @property
    def current_epoch_step(self) -> int:
        """
        Get the value of step within an epoch.
        """
        return max(
            self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed,
            self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
        )

    @override
    def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
        # Taken from MegatronStrategy
        if self.is_global_zero:
            shutil.rmtree(ckpt_to_dir(filepath))

    @override
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        """Converts PyT checkpoints to MCore format and saves them using the MCore dist ckpt library."""
        checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict"))
        checkpoint["state_dict"] = OrderedDict([])

        # TODO: do we still need to keep this?
        for optim_state in checkpoint['optimizer_states']:
            optim_state.pop("state")

        if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer:
            checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
            pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.")

        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    @override
    def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]:
        """PTL method which we override to integrate distributed checkpoints for FSDP models.
        Unlike MegatronStrategy, both model and optimizer states are restored within
        this method.

        The logic here is slightly more complicated:
        1. Obtain PyT state dicts (sharded & unflattened) for model and optim -> torch::ShardedTensor
        2. Convert to MCore state dicts -> mcore::ShardedTensor
        3. Load from checkpoint using MCore dist ckpt API -> torch::Tensor
        4. Convert to PyT state dicts (sharded & unflattened) -> torch::ShardedTensor
        5. Load into model and optim using PyT dist ckpt API
        6. Return the loaded checkpoint for lightning to load other metadata
        """
        path = Path(self.broadcast(checkpoint_path))
        torch.cuda.empty_cache()

        # TODO: the elegant way to load both state dicts. Need pytorch 2.3.1
        # msd, osd = get_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))

        # Steps 1-2: build the sharded PyT model state dict and convert it to MCore format.
        sharded_state_dict = {}
        with _get_sharded_state_dict_context(self.model):
            msd = self.model.state_dict()
            pyt_to_mcore_state_dict(msd)
            sharded_state_dict["sharded_state_dict"] = msd

        if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
            osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))
            pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.")
            sharded_state_dict["optimizer"] = osd

        # Step 3: load the tensors from the checkpoint with the MCore dist ckpt API.
        checkpoint = self.checkpoint_io.load_checkpoint(path, sharded_state_dict=sharded_state_dict)
        # Step 4: convert the loaded tensors back into PyT sharded state dicts.
        mcore_to_pyt_sharded_state_dict(checkpoint['sharded_state_dict'], msd)

        if self.ckpt_include_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
            mcore_to_pyt_sharded_state_dict(checkpoint['optimizer']['state'], osd['state'])

        # Step 5: load the state dicts into the model and optimizer with the PyT dist ckpt API.
        set_state_dict(
            self.model,
            self.optimizers if self.ckpt_include_optimizer else [],
            model_state_dict=checkpoint['sharded_state_dict'],
            optim_state_dict=checkpoint['optimizer'] if self.ckpt_include_optimizer else None,
        )

        # Step 6: return the loaded checkpoint so Lightning can restore remaining metadata.
        return checkpoint
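
A minimal usage sketch (an editorial illustration, not part of the PR; the device count and precision settings are placeholder assumptions, with `Trainer` and `MegatronMixedPrecision` taken from the existing nemo.lightning exports shown in the first diff):

from nemo import lightning as nl

# Shard each TransformerLayer (the default auto_wrap_policy) and include optimizer
# state in the distributed checkpoint.
strategy = nl.FSDPStrategy(ckpt_include_optimizer=True)

trainer = nl.Trainer(
    devices=8,
    accelerator="gpu",
    strategy=strategy,
    plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)

Because FSDPStrategy keeps the model-weight layout compatible with MegatronStrategy (see the class docstring), weights saved under one strategy can be reloaded under the other, as long as only the weights, not the optimizer states, are needed.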