enforce recompute flag on fsdpa quantization (#133)
dudilester authored and astachowiczhabana committed Apr 19, 2024
1 parent b79036e commit b8c073a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1,3 +1,4 @@
+import os
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -414,7 +415,8 @@ def pre_attn_forward(
 
             if q_len == 1:
                 # next token
-                with ht.sdp_kernel(enable_recompute=False):
+                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
+                with ht.sdp_kernel(enable_recompute=use_recompute):
                     attn_output = self.fused_scaled_dot_product_attention(
                         query_states, key_states, value_states, attention_mask, 0.0, False, None
                     )
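For context, a minimal sketch of the logic this change introduces: the SDPA kernel's recompute mode is now derived from the QUANT_CONFIG environment variable (set when running with a quantization config for fused SDPA) instead of being hard-coded to False. The helper name and the config path below are illustrative, not part of the repository.

import os

def _use_recompute() -> bool:
    # A non-empty QUANT_CONFIG signals that fused-SDPA quantization is active,
    # so recompute is enforced; otherwise the previous default (False) applies.
    return bool(os.getenv("QUANT_CONFIG", ""))

# Without QUANT_CONFIG the flag stays False, matching the old hard-coded value.
os.environ.pop("QUANT_CONFIG", None)
assert _use_recompute() is False

# With a quantization config set (the path here is only an example),
# recompute is enabled for the fused scaled-dot-product attention call.
os.environ["QUANT_CONFIG"] = "quantization_config.json"
assert _use_recompute() is True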
