Remove model_state.use_fp8_ddp and optimizer.all_reduce_grads #145

Open · wants to merge 4 commits into base: main
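This PR removes the `model_state.use_fp8_ddp` flag and the `optimizer.all_reduce_grads(model)` helper: gradient synchronization of FP8 `ScalingParameter`s is now handled entirely by `msamp.nn.distributed.FP8DistributedDataParallel`, whose `_DDPSink` hook now runs whenever gradients are enabled instead of being gated on the flag.

The sketch below illustrates the usage change implied by the diff. It is a minimal single-process example built only from names that appear in this diff; the process-group initialization required by DDP is assumed rather than shown.

```python
import torch

from msamp.common.dtype import Dtypes
from msamp.nn import LinearReplacer
from msamp.nn.distributed import FP8DistributedDataParallel
from msamp.optim import LBAdamW

# Build an FP8 linear model and a low-bit optimizer, mirroring the test setup in this repo.
model = LinearReplacer.replace(torch.nn.Linear(4, 4).cuda(), Dtypes.kfloat16)
opt = LBAdamW(list(model.parameters()))

# Before this PR, FP8 weight gradients had to be synced by hand after backward():
#     loss.backward()
#     opt.all_reduce_grads(model)   # helper removed by this PR
#     opt.step()

# After this PR, wrap the model once; the _ScalingTensorReducer inside
# FP8DistributedDataParallel all-reduces the FP8 gradients during backward
# (assumes torch.distributed is already initialized).
ddp_model = FP8DistributedDataParallel(model)

loss = ddp_model(torch.randn(4, 4, device='cuda')).sum()
loss.backward()
opt.step()
opt.zero_grad()
```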
4 changes: 1 addition & 3 deletions msamp/nn/distributed.py
@@ -11,7 +11,6 @@
from msamp.common.tensor import ScalingTensor, ScalingMeta
from msamp.common.dtype import Dtypes, Floating
from msamp.common.utils import TransformerEngineWrapper
from msamp.nn.state import model_state
from msamp.operators.dist_op import DistOp


@@ -246,7 +245,6 @@ def __init__(self, module, **kwargs):
self.scaling_tensor_reducer = _ScalingTensorReducer(
scaling_params, self.process_group, self.bucket_bytes_cap
)
model_state.use_fp8_ddp = True

def forward(self, *inputs, **kwargs):
"""Apply _DDPSink in forward function.
@@ -255,7 +253,7 @@
inputs (tuple): The input tensors.
kwargs (dict): The keyword arguments.
"""
if model_state.use_fp8_ddp and torch.is_grad_enabled():
if torch.is_grad_enabled():
inputs = _DDPSink.apply(self.scaling_tensor_reducer, torch.tensor([], requires_grad=True), *inputs)
out = super().forward(*inputs, **kwargs)
return out
1 change: 1 addition & 0 deletions msamp/nn/functional.py
@@ -109,6 +109,7 @@ def backward(ctx, output_grad):
use_split_accumulator=True,
)
del old_wgrad

if hasattr(ctx, 'return_wgrad') and ctx.return_wgrad:
wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True)
wgrad = wgrad.value.view(-1).view(dtype=torch.float32)
4 changes: 2 additions & 2 deletions msamp/nn/linear.py
@@ -183,8 +183,8 @@ def replace(cls, model, weight_qtype=Dtypes.kfloat16, src_rank=0, group=None):
for k, p in fp8_named_weights:
p._param_name = k

# DDP ignores the FP8 weights, and the optimizer provides a function `optimizer.all_reduce_grads(model)`
# to sync them.
# The native DDP ignores the FP8 weights,
# and msamp.nn.distributed.FP8DistributedDataParallel will handle them.
fp8_names = []
for module_name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
15 changes: 0 additions & 15 deletions msamp/nn/state.py
@@ -21,7 +21,6 @@ def __init__(self):
# OrderedDict[str, dict[str, ScalingMeta]], store the local scaling metas in all FP8Linear modules.
# key is module name, value is scaling_metas in FP8Linear module.
self._local_scaling_metas = OrderedDict()
self._use_fp8_ddp = False

@property
def ready_to_scale_tensor(self):
@@ -42,20 +41,6 @@ def flattened_scaling_metas(self):
"""Decoration function to access _flattened_scaling_metas variable."""
return self._flattened_scaling_metas

@property
def use_fp8_ddp(self):
"""Decoration function to access _use_fp8_ddp variable."""
return self._use_fp8_ddp

@use_fp8_ddp.setter
def use_fp8_ddp(self, value):
"""Set the value of _use_fp8_ddp variable.

Args:
value (bool): Value to set.
"""
self._use_fp8_ddp = value

@flattened_scaling_metas.setter
def flattened_scaling_metas(self, value):
"""Set the value of _flattened_scaling_metas variable.
11 changes: 1 addition & 10 deletions msamp/optim/optimizer.py
@@ -13,8 +13,7 @@

from msamp.common.dtype import Floating
from msamp.common.tensor import ScalingTensor, ScalingMeta
from msamp.common.tensor import TensorDist
from msamp.nn import model_state, ScalingParameter
from msamp.nn import model_state


class LBOptimizer(Optimizer):
@@ -42,14 +41,6 @@ def step(self, closure=None):
self._update_scaling_factors()
return rtn

def all_reduce_grads(self, model):
"""All-reduce gradients of parameters."""
if model_state.use_fp8_ddp:
return
scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
grads = [p.grad for p in scaling_params if p.grad is not None]
TensorDist.all_reduce_avg(grads)

def lb_step(self, closure=None):
"""Performs a single optimization step. The subclass needs to implement this method.

2 changes: 1 addition & 1 deletion tests/nn/test_linear.py
@@ -56,7 +56,7 @@ def test_fp8linear_backward(self):
self.assertTrue(torch.equal(fp8linear.bias.grad, linear.bias.grad))

# check weight.
self.assertTrue(isinstance(fp8linear.weight.grad, ScalingTensor))
self.assertTrue(isinstance(fp8linear.weight.grad, torch.Tensor))
self.assertTrue(fp8linear.weight.grad.size() == linear.weight.grad.size())

@decorator.cuda_test
33 changes: 1 addition & 32 deletions tests/optim/test_adamw.py
@@ -11,7 +11,7 @@
from functools import partial

from msamp.common.dtype import Dtypes
from msamp.common.tensor import TensorDist, ScalingTensor
from msamp.common.tensor import ScalingTensor
from msamp.optim import LBAdamW, LBAdam, LBAdamWBase, DSAdam
from msamp.nn import LinearReplacer
from tests.helper import decorator
@@ -78,37 +78,11 @@ def check_optimizer_step(self, optimizer_class1, optimizer_class2, diff=3e-4):
for _ in range(steps):
output = model2(input)
output.sum().backward()
opt2.all_reduce_grads(model2)
opt2.step()
opt2.zero_grad()

self.assertTrue(torch.allclose(model1.weight, model2.weight.float(), 0, diff))

def test_all_reduce_grads(self):
"""Test the function `all_reduce_grads`."""
input = torch.randn(4, 4, device='cuda')
model1 = torch.nn.Linear(4, 4).cuda()
model2 = torch.nn.Linear(4, 4).cuda()
model1 = LinearReplacer.replace(model1, Dtypes.kfloat16)
model2 = LinearReplacer.replace(model2, Dtypes.kfloat16)
opt = LBAdamW(list(model1.parameters()) + list(model2.parameters()))
loss = (model1(input) + model2(input)).sum()
loss.backward()
old_all_reduce_avg = TensorDist.all_reduce_avg
num_grads = 0

def debug_all_reduce_avg(grads):
nonlocal num_grads
num_grads += len(grads)
return old_all_reduce_avg(grads)

TensorDist.all_reduce_avg = debug_all_reduce_avg
opt.all_reduce_grads(model1)
self.assertEqual(num_grads, 1)
opt.all_reduce_grads(model2)
self.assertEqual(num_grads, 2)
TensorDist.all_reduce_avg = old_all_reduce_avg

def check_optimizer_state_dict(self, lbadam_class):
"""Save and load state dict of lbadam_class optimizer and check if the value is excepted.

@@ -127,7 +101,6 @@ def check_optimizer_state_dict(self, lbadam_class):
output = model1(input)
opt1.zero_grad()
output.sum().backward()
opt1.all_reduce_grads(model1)
opt1.step()

state_dict1 = opt1.state_dict()
@@ -158,7 +131,6 @@ def check_optimizer_state_dict(self, lbadam_class):
state_dict2 = copy.deepcopy(state_dict1)
opt1.zero_grad()
model1(input).sum().backward()
opt1.all_reduce_grads(model1)
opt1.step()

# Build model2 and update 4 times.
@@ -171,7 +143,6 @@ def check_optimizer_state_dict(self, lbadam_class):
output = model2(input)
opt2.zero_grad()
output.sum().backward()
opt2.all_reduce_grads(model2)
opt2.step()

# Load state dict to op2 and check if the weight is same as model1 after update weigth once.
@@ -180,7 +151,6 @@ def check_optimizer_state_dict(self, lbadam_class):

opt2.zero_grad()
model2(input).sum().backward()
opt2.all_reduce_grads(model2)
opt2.step()

self.assertTrue(torch.equal(model1.weight.value, model2.weight.value))
@@ -216,5 +186,4 @@ def test_historical_window_quantization(self):
y = model(x)
self.assertTrue((model.scaling_metas['input'].amax.max() == max(windows)).all())
y.sum().backward()
opt.all_reduce_grads(model)
opt.step()