Megatron Export Update #5423

Merged
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/asr_model.py
@@ -21,7 +21,8 @@
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.utils import cast_all, logging, model_utils
from nemo.utils import logging, model_utils
from nemo.utils.cast_utils import cast_all

__all__ = ['ASRModel']

nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
@@ -42,6 +42,7 @@
from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb
from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.core.classes import Exportable
from nemo.utils import AppState, logging, timers

try:
@@ -56,7 +57,7 @@
__all__ = ["MegatronNMTModel"]


class MegatronNMTModel(MegatronLMEncoderDecoderModel):
class MegatronNMTModel(MegatronLMEncoderDecoderModel, Exportable):
"""
Megatron NMT training
"""
@@ -750,5 +751,12 @@ def decoder(self):
return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)

@property
def classifier(self):
def log_softmax(self):
return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device)

@property
def input_module(self):
return self.encoder

def list_export_subnets(self):
return ['encoder', 'log_softmax', 'decoder']
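Because the model now mixes in Exportable and reports its sub-networks through list_export_subnets(), a single export call is expected to emit one graph per sub-network rather than one monolithic ONNX file. A minimal usage sketch, assuming a restored NMT checkpoint and NeMo's standard Exportable.export() entry point (the file names below are placeholders, not part of this PR):

from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel

model = MegatronNMTModel.restore_from("megatron_nmt.nemo")  # hypothetical checkpoint name
model.eval()
# Exportable.export() consults list_export_subnets(), so the expectation here is one
# ONNX graph each for 'encoder', 'log_softmax', and 'decoder'.
model.export("megatron_nmt.onnx", check_trace=False)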
50 changes: 33 additions & 17 deletions nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -49,21 +49,23 @@ def forward(self, dec_output):
if isinstance(dec_output, list):
dec_output = dec_output[0]

dec_output = torch.permute(dec_output, (1, 0, 2))

if self.tokens_head_bias is not None:
return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight, self.tokens_head_bias)
return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight)

def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
def input_example(self, max_batch=1, max_dim=768, seq_len=6):
return [
torch.randint(low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32)
torch.randint(low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32)
]

def freeze(self):
for param in self.parameters():
param.requires_grad = False

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"hidden_states": NeuralType(('T', 'B', 'D'), ChannelType()),
"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()),
}

@property
@@ -107,18 +109,28 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec
# dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask
_ = dec_mems

return self.decoder(dec_input, decoder_mask, encoder_embeddings, encoder_mask).float()
return (
self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
.float()
.permute(1, 0, 2)
)

def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
def freeze(self):
for param in self.parameters():
param.requires_grad = False

def input_example(self, max_batch=1, max_dim=768, seq_len=6):
enc_output = torch.randint(
low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32
low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32
)
enc_attn_mask = torch.tensor([[1 for _ in range(seq_len)]]).to(self.device)

dec_len = random.randint(10, 128)
dec_input = torch.randint(low=0, high=1000, size=(max_batch, dec_len), device=self.device)
dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device)
decoder_mems = torch.zeros([8, 6, 1024], dtype=torch.float32).to(self.device)

# constant decoder_mems as placeholder for now
decoder_mems = torch.zeros([8, 6, max_dim], dtype=torch.float32).to(self.device)

# input_ids, decoder_mask, encoder_mask, encoder_embeddings
return (dec_input, dec_attn_mask, enc_attn_mask, enc_output, decoder_mems)
@@ -128,14 +140,14 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"input_ids": NeuralType(('B', 'T', 'D'), ChannelType()),
"decoder_mask": NeuralType(('B', 'T'), MaskType()),
"encoder_mask": NeuralType(('T', 'B', 'D'), ChannelType()),
"encoder_mask": NeuralType(('B', 'T', 'D'), ChannelType()),
"encoder_embeddings": NeuralType(('B', 'T'), MaskType()),
"decoder_mems": NeuralType(('T', 'B', 'D'), ChannelType()),
"decoder_mems": NeuralType(('B', 'T', 'D'), ChannelType()),
}

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

@property
def input_names(self) -> List[str]:
@@ -172,15 +184,19 @@ def forward(self, input_ids, encoder_mask):
enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)

# pass input through the encoder
return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32)
return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2)

def input_example(self):
def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
seq_len = random.randint(0, 128)
return (
torch.randint(0, 30000, (1, seq_len)).to(self.device),
torch.ones((1, seq_len), dtype=int).to(self.device),
torch.randint(0, max_dim, (max_batch, seq_len)).to(self.device),
torch.ones((max_batch, seq_len), dtype=int).to(self.device),
)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
@@ -190,7 +206,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

@property
def input_names(self) -> List[str]:
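The recurring permute(1, 0, 2) calls above all serve one convention change: Megatron's internals are sequence-first (T, B, D), while the exported wrappers now expose batch-first (B, T, D) tensors, matching the updated input_types/output_types and input_example shapes. A minimal illustration of that layout switch (shapes are arbitrary examples):

import torch

enc_internal = torch.zeros(6, 1, 768)         # (T, B, D) as produced inside Megatron
enc_exported = enc_internal.permute(1, 0, 2)  # (B, T, D) as exposed by the export wrappers
assert enc_exported.shape == (1, 6, 768)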
91 changes: 62 additions & 29 deletions nemo/utils/export_utils.py
@@ -59,21 +59,40 @@ def forward(self, x):
return F.linear(x, self.weight, self.bias), None


# ScaledMaskedSoftmax replacement
def mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores


def exportable_ScaledMaskedSoftmax(input, mask, scale):
if scale is not None:
input = input * scale

mask_output = mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)

probs = probs.half()
return probs
class ExportableMatchedScaleMaskSoftmax(nn.Module):
def __init__(self, mod):
super(ExportableMatchedScaleMaskSoftmax, self).__init__()
self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale)

def init_module(
self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale,
):
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
self.softmax_in_fp32 = softmax_in_fp32
self.mask_func = mask_func
self.scale = scale

self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

def forward(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()

if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
all_k_masked = mask.all(axis=-1)
zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
probs = probs * zero_attention_mask

if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs


def get_export_format(filename: str):
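ExportableMatchedScaleMaskSoftmax replaces the fused Apex kernel with plain PyTorch ops that ONNX tracing can follow, and additionally zeroes attention rows whose keys are all masked. A rough usage sketch, assuming the class is importable from nemo.utils.export_utils once this change lands; the SimpleNamespace stand-in below is an assumption that only mimics the attributes the wrapper reads (normally this would be a FusedScaleMaskSoftmax module taken from a loaded Megatron model):

import torch
from types import SimpleNamespace
from nemo.utils.export_utils import ExportableMatchedScaleMaskSoftmax

# Stand-in carrying the flags the wrapper copies via init_module().
fused_like = SimpleNamespace(
    input_in_fp16=False,
    input_in_bf16=False,
    softmax_in_fp32=False,
    scale=None,
    mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
)

wrapper = ExportableMatchedScaleMaskSoftmax(fused_like)
scores = torch.randn(1, 16, 8, 8)                  # (batch, heads, query, key) attention scores
mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)   # True marks key positions to suppress
probs = wrapper(scores, mask)                      # attention probabilities, same shape as scores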
@@ -159,10 +178,12 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n")
onnx.checker.check_model(onnx_model, full_check=True)
return
del onnx_model
onnx_session_opt = onnxruntime.SessionOptions()
onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = onnxruntime.InferenceSession(output, sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'])
onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess = onnxruntime.InferenceSession(
onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider']
)
del onnx_model
all_good = True
for input_example in input_examples:
input_list, input_dict = parse_input_example(input_example)
@@ -227,27 +248,24 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
from apex.transformer.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax

def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]:
def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
"""
Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export.
Args:
n: the FusedLayerNorm pytorch module to replace
Returns:
Equivalent LayerNorm module
"""
if (
not isinstance(n, FusedLayerNorm)
and not isinstance(n, FastLayerNorm)
and not isinstance(n, MixedFusedLayerNorm)
):
return None

dev = next(n.parameters()).device
p = next(n.parameters())
if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev)
shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
elif isinstance(n, FastLayerNorm):
mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float16,).to(dev)
shape, eps, affine = n.weight.shape, n.epsilon, True
else:
return None

mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
n_state = n.state_dict()
mod.load_state_dict(n_state)
return mod
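The rewritten replace_FusedLayerNorm builds the nn.LayerNorm directly with the source parameter's device and dtype instead of calling .to(dev) afterwards, so half-precision weights keep their dtype. A small self-contained illustration of the state_dict handoff it relies on (using a plain nn.LayerNorm as a stand-in source, since Apex may not be installed):

import torch
import torch.nn as nn

src = nn.LayerNorm(768, eps=1e-5).half()        # stand-in for a fused, half-precision layer norm
p = next(src.parameters())
dst = nn.LayerNorm(src.normalized_shape, eps=src.eps,
                   elementwise_affine=src.elementwise_affine,
                   device=p.device, dtype=p.dtype)
dst.load_state_dict(src.state_dict())           # weight and bias carry over unchanged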
Expand All @@ -264,7 +282,7 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
raise ValueError("This function can only change the RowParallelLinear module.")

dev = next(n.parameters()).device
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)

n_state = n.state_dict()
mod.load_state_dict(n_state)
@@ -340,6 +358,20 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
return expansion_fn


def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
"""
Replaces MatchedScaleMaskSoftmax with exportable softmax layer
Args:
n: module to replace
Returns:
exportable module
"""

mod = ExportableMatchedScaleMaskSoftmax(n.input_in_fp16, n.input_in_bf16, n.mask_func, n.softmax_in_fp32, n.scale)

Check failure (Code scanning / CodeQL): Wrong number of arguments in a class instantiation.
Call to ExportableMatchedScaleMaskSoftmax.__init__ with too many arguments; should be no more than 1.

return mod
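Given the one-argument constructor defined earlier in this diff, a likely resolution of the CodeQL finding (an assumption, not something contained in this changeset) is to hand the module itself to the wrapper and let init_module() pull the flags from it:

# Hypothetical fix: the wrapper's __init__ expects the module to wrap, not its unpacked attributes.
mod = ExportableMatchedScaleMaskSoftmax(n)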


def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
Generic function generator to replace BaseT module with DestT wrapper.
@@ -408,6 +440,7 @@ def script_module(m: nn.Module):
"BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
"BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
"LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
"MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
}

script_replacements = {
1 change: 1 addition & 0 deletions scripts/export.py
@@ -139,6 +139,7 @@ def nemo_export(argv):
with autocast(), torch.no_grad(), torch.inference_mode():
model.to(device=args.device).freeze()
model.eval()
input_example = None
if check_trace and len(in_args) > 0:
input_example = model.input_module.input_example(**in_args)
check_trace = [input_example]