 import torch_xla2
 from jax import lax
 from jetstream_pt import torchjax
+from jetstream_pt.environment import QuantizationConfig
 from jetstream_pt.quantize import (
     dequantize_tensor,
     load_q_weight_helper,
     quantize_tensor,
+    blockwise_jax_kernel,
+    blockwise_jax_kernel_dot_general,
+    blockwise_jax_kernel_einsum_flatten,
 )
 from torch import nn
 from . import attention_kernel as ak
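Note: the hunks below read several attributes off the new QuantizationConfig argument. A minimal sketch of the attributes this diff relies on (defaults come from jetstream_pt.environment; the printed values are illustrative only):

```python
# Sketch, not part of the diff: QuantizationConfig attributes referenced below.
from jetstream_pt.environment import QuantizationConfig

quant_config = QuantizationConfig()
print(quant_config.num_bits_weight)                 # weight bit width (replaces n_bit)
print(quant_config.is_symmetric_weight)             # symmetric vs. zero-point weight quant
print(quant_config.enable_activation_quantization)  # per-token activation quantization
print(quant_config.block_size_weight)               # block size for blockwise matmul
```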
@@ -68,8 +72,7 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
@@ -85,8 +88,9 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)
 
-    self.is_symmetric = is_symmetric
-    if not is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (out_features,), dtype=torch.bfloat16, device=device
       )
@@ -96,7 +100,12 @@ def __init__(
 
     assert not bias, "Quantized Linear doesn't support bias."
 
-    self.n_bit = n_bit
+    # Number of bits of weight tensor
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False
 
@@ -115,23 +124,40 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
+        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     self._load_quantized_weights(w_q, scale, zp)
 
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.is_symmetric:
-        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+      if self.quantize_activation:
+        inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
+      if not self.quantize_activation:
+        result = F.linear(inputs, self.weight)
       else:
-        out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+        # We have to call jax because we need to do dot(int8, int8)->int32.
+        # This semantic cannot be represented in torch. The inferred output dtype
+        # will be int8 in torch, causing the dot result to overflow.
+        result = torchjax.call_jax(
+            jax.lax.dot_general,
+            inputs,
+            self.weight,
+            (((2,), (1)), ((), ())),
+            None,
+            jnp.int32.dtype,
+        )
+      result = result * self.weight_scaler
+      if self.quantize_activation:
+        result = result * act_s
+      if not self.is_symmetric_weight:
         zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
-        return out - zp_out
+        result = result - zp_out
+      return result
     else:
       # Fake quantization, debugging purpose.
       scaler = self.weight_scaler.unsqueeze(-1)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1) / scaler
       else:
         zero_point = None
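The activation-quantized branch above routes the matmul through JAX because torch infers an int8 output for an int8 x int8 matmul, which overflows. A minimal standalone sketch of that dot_general call (shapes and names here are illustrative, not taken from the diff):

```python
# Sketch, not part of the diff: int8 x int8 -> int32 matmul via jax.lax.dot_general.
import jax
import jax.numpy as jnp

def int8_matmul(x_q, w_q):
  # Contract the feature dim of x_q (B, T, in) with dim 1 of w_q (out, in),
  # accumulating in int32 so the int8 products cannot overflow.
  return jax.lax.dot_general(
      x_q,
      w_q,
      dimension_numbers=(((2,), (1,)), ((), ())),
      preferred_element_type=jnp.int32,
  )

x_q = jnp.ones((1, 4, 8), dtype=jnp.int8)
w_q = jnp.ones((16, 8), dtype=jnp.int8)
print(int8_matmul(x_q, w_q).dtype)  # int32
```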
@@ -149,32 +175,37 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      use_dot_general=False,
-      block_size=128,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features
 
     # Use dot general instead of einsum
     # Use dot general is slow now.
-    self.use_dot_general = use_dot_general
+    self.use_dot_general = False
     # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
     # Same perf as non flattened one now.
     self.flatten = False
 
-    self.block_size = block_size
-    n_blocks = in_features // block_size
+    self.block_size = quant_config.block_size_weight
+    n_blocks = in_features // self.block_size
+
+    assert (
+        not quant_config.enable_activation_quantization
+    ), "Activation quantization not supported for blockwise quantized matmul."
 
     if self.use_dot_general:
       weight = torch.ones(
-          (n_blocks, out_features, block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size),
+          dtype=torch.int8,
+          device=device,
       )
     else:
       weight = torch.ones(
-          (n_blocks, block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features),
+          dtype=torch.int8,
+          device=device,
       )
     self.register_buffer("weight", weight)
 
@@ -183,16 +214,20 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)
 
-    self.is_symmetric = is_symmetric
-    if not self.is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (n_blocks, out_features), dtype=torch.bfloat16, device=device
       )
       self.register_buffer("zero_point", zero_point)
     else:
       self.register_buffer("zero_point", None)
 
-    self.n_bit = n_bit
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False
 
@@ -211,112 +246,37 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, self.block_size
+        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
-    print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
-    # breakpoint()
     self._load_quantized_weights(w_q, scale, zp)
 
-  @staticmethod
-  def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
-    """Blockwise Matmul kernel impl in JAX using einsum"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
-    out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
-    if zero_point is not None:
-      zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
-      out = out - zp_out
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_dot_general(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using dot general"""
-    inputs_shape = inputs.shape
-    block_size = weight.shape[2]
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jax.lax.dot_general(
-        inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
-    )
-    out = jax.lax.dot_general(
-        out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
-    )
-    out = jax.lax.transpose(out, [1, 0])
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_einsum_flatten(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jnp.einsum("scz,bsc->bsz", weight, inputs)
-    out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.use_dot_general:
+      if self.use_dot_general or self.flatten:
         assert (
             self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      if self.flatten:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      else:
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
+        ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = (
+          blockwise_jax_kernel
+          if not self.use_dot_general and not self.flatten
+          else blockwise_jax_kernel_dot_general
+          if self.use_dot_general
+          else blockwise_jax_kernel_einsum_flatten
+      )
+      result = torchjax.call_jax(
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
+      return result
     else:
       # Fake quantization, debugging purpose.
       weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
       scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
       else:
         zero_point = None
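The blockwise kernels removed above now live in jetstream_pt.quantize and are selected by the conditional expression in forward. A minimal standalone sketch of the einsum contraction the default kernel performs (shapes are illustrative; zero-point handling omitted):

```python
# Sketch, not part of the diff: the blockwise einsum contraction used by
# blockwise_jax_kernel. weight: (n_blocks s, block_size c, out z),
# inputs: (batch b, seq d, in_features = s * c), scaler: (s, z).
import jax.numpy as jnp

s, c, z, b, d = 4, 32, 16, 1, 2
weight = jnp.ones((s, c, z), dtype=jnp.int8)
scaler = jnp.ones((s, z), dtype=jnp.bfloat16)
x = jnp.ones((b, d, s * c), dtype=jnp.bfloat16)

x_blocks = x.reshape(b, d, s, c)                      # split features into blocks
out = jnp.einsum("scz,bdsc->bdsz", weight, x_blocks)  # per-block matmul
out = jnp.einsum("bdsz,sz->bdz", out, scaler)         # apply per-block scales, sum blocks
print(out.shape)  # (1, 2, 16)
```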
@@ -554,12 +514,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
     self.hidden_size = hidden_size
 
     LinearLayer = get_quantized_linear_layer(env.quant_config)
+    linear_kwargs = {}
+    if LinearLayer != torch.nn.Linear:
+      linear_kwargs = {"quant_config": env.quant_config}
 
     self.wo = LinearLayer(
         n_heads * self.head_dim,
         hidden_size,
         bias=False,
         device=device,
+        **linear_kwargs,
     )
 
     Kernel = (
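The linear_kwargs dict above exists because torch.nn.Linear does not accept quant_config, while the quantized layer classes do. A minimal standalone sketch of the same pattern (the helper name is hypothetical):

```python
# Sketch, not part of the diff: pass quant_config only to quantized layer classes.
import torch

def make_linear(LinearLayer, in_f, out_f, device, quant_config=None):
  kwargs = {}
  if LinearLayer != torch.nn.Linear and quant_config is not None:
    kwargs = {"quant_config": quant_config}  # quantized layers take this kwarg
  return LinearLayer(in_f, out_f, bias=False, device=device, **kwargs)

# A plain nn.Linear is built without the extra kwarg:
layer = make_linear(torch.nn.Linear, 512, 512, device="cpu")
```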
@@ -578,25 +542,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
           (n_heads + 2 * self.n_kv_heads) * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
     else:
       self.wq = LinearLayer(
           hidden_size,
           n_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wk = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wv = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
 
   def load_hook(self, state_dict, prefix, *args):