Skip to content

Commit

Permalink
Fixed an enum value comparison error (bigscience-workshop#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
rraminen committed Jun 8, 2023
1 parent d5c822e commit 7491937
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions megatron/model/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ def forward(self, input, mask):
custom_kernel_constraint and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0

if self.attn_mask_type == AttnMaskType.causal:
if self.attn_mask_type.value == AttnMaskType.causal.value:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
assert self.attn_mask_type.value == AttnMaskType.padding.value
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else:
if self.input_in_float16 and self.softmax_in_fp32:
Expand Down

0 comments on commit 7491937

Please sign in to comment.