In [None]:
import sys

In [None]:
sys.path.append("../../scripts/")

In [None]:
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from accelerate import init_empty_weights
import safetensors
from safetensors.torch import load_file

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from transformers.models.qwen2.modeling_qwen2 import Qwen2FlashAttention2, Qwen2SdpaAttention

In [None]:
from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig
from quant_utils import replace_linear

In [None]:
transformers.__file__, transformers.__version__

('/home/k/git/transformers/src/transformers/__init__.py', '4.46.0.dev0')

In [None]:
# # KV fp8 quantization.
# self.k_scale = config.get("k_scale", None)
# self.v_scale = config.get("v_scale", None)

# # Cross Layer Attention (CLA).
# self.compute_new_kv = config.get("compute_new_kv", True)

In [None]:
m = torch.nn.Linear(128,64,bias=True)

In [None]:
m.bias.shape

torch.Size([64])

In [None]:
model_name = "Qwen/Qwen2.5-32B-Instruct"

In [None]:
# Start from the real Qwen2.5-32B config and shrink it into a tiny 4-layer model for
# fast experiments: widths are divided by 8 and both head counts are halved.
cfg = AutoConfig.from_pretrained(model_name)
cfg._attn_implementation = "eager"
cfg.num_hidden_layers = 4
cfg.hidden_size = cfg.hidden_size // 8
cfg.intermediate_size = cfg.intermediate_size // 8
cfg.num_attention_heads = cfg.num_attention_heads // 2
cfg.num_key_value_heads = cfg.num_key_value_heads // 2

In [None]:
cfg

Qwen2Config {
  "_name_or_path": "Qwen/Qwen2.5-32B-Instruct",
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 640,
  "initializer_range": 0.02,
  "intermediate_size": 3456,
  "max_position_embeddings": 32768,
  "max_window_layers": 70,
  "model_type": "qwen2",
  "num_attention_heads": 20,
  "num_hidden_layers": 4,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.0.dev0",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 152064
}

In [None]:
with init_empty_weights(): model = AutoModelForCausalLM.from_config(cfg)

In [None]:
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 640)
    (layers): ModuleList(
      (0-3): 4 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=640, out_features=640, bias=True)
          (k_proj): Linear(in_features=640, out_features=128, bias=True)
          (v_proj): Linear(in_features=640, out_features=128, bias=True)
          (o_proj): Linear(in_features=640, out_features=640, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=640, out_features=3456, bias=False)
          (up_proj): Linear(in_features=640, out_features=3456, bias=False)
          (down_proj): Linear(in_features=3456, out_features=640, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((640,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((640,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((640,),

In [None]:
cfg._attn_implementation = "sdpa"
cfg.use_fp8_kv_scale = True
cfg.cla_kv_cache_map = {0:0, 1:1, 2:1, 3:0}
cfg.palu_kv_compression_enabled = False

In [None]:
# Build the model for the modified config without allocating real weight memory.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
# Top-level last expression so IPython renders the repr (an expression nested inside
# the `with` statement is not displayed).
model

KV fp8 quantization is enabled.
Cross Layer Attention (CLA) is enabled.


In [None]:
for n,p in model.model.layers[0].self_attn.named_parameters(): print(n, p.shape)

k_scale torch.Size([])
v_scale torch.Size([])
q_proj.weight torch.Size([640, 640])
q_proj.bias torch.Size([640])
k_proj.weight torch.Size([128, 640])
k_proj.bias torch.Size([128])
v_proj.weight torch.Size([128, 640])
v_proj.bias torch.Size([128])
o_proj.weight torch.Size([640, 640])


In [None]:
cfg._attn_implementation = "eager"
cfg.use_fp8_kv_scale = True
cfg.cla_kv_cache_map = {0:0, 1:1, 2:1, 3:0}
cfg.palu_kv_compression_enabled = False
cfg.use_cache = False
cfg.debug_kv_sharing = True
cfg.output_attentions = True

model = AutoModelForCausalLM.from_config(cfg)
model.to(device="cuda", dtype=torch.bfloat16);



In [None]:
# model.model.layers[0].self_attn.k_scale, model.model.layers[0].self_attn.v_scale

In [None]:
model.config.use_fp8_kv_scale, model.config.cla_kv_cache_map, model.config.use_cache

(True, {0: 0, 1: 1, 2: 1, 3: 0}, False)

In [None]:
# for n,p in model.model.layers[0].self_attn.named_parameters(): print(n, p.shape, p.dtype)

In [None]:
x = torch.arange(32, device="cuda").view(1,-1); x

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],
       device='cuda:0')

In [None]:
# Evaluation-only forward pass: no_grad keeps the returned logits/attentions (and any
# debug tensors the modules store) from holding an autograd graph in GPU memory.
with torch.no_grad():
    out_eager = model(x)

In [None]:
model_state_dict = model.state_dict()

In [None]:
# CLA sharing check for cla_kv_cache_map = {0: 0, 1: 1, 2: 1, 3: 0}: layers 0/3 and
# 1/2 are expected to match exactly, while the two groups must differ.
# Gate the attention-map checks on what the forward pass actually returned rather than
# on the shared `cfg` (a later cell switches it to "sdpa"), so re-running this cell
# out of order does not silently skip them.
eager_attentions = out_eager.attentions
if eager_attentions is not None and eager_attentions[0] is not None:
    assert torch.equal(eager_attentions[0], eager_attentions[3]), "attention maps of layers 0 and 3 differ"
    assert torch.equal(eager_attentions[1], eager_attentions[2]), "attention maps of layers 1 and 2 differ"
    assert not torch.equal(eager_attentions[0], eager_attentions[1]), "attention maps of layers 0 and 1 are identical"

# Per-layer attention outputs stored on each attention module (presumably populated
# because cfg.debug_kv_sharing = True in the config cell).
attn_outputs = [layer.self_attn.debug_cla_attn_output for layer in model.model.layers]
assert len(attn_outputs) == 4, f"expected 4 debug outputs, got {len(attn_outputs)}"
assert torch.equal(attn_outputs[0], attn_outputs[3]), "attention outputs of layers 0 and 3 differ"
assert torch.equal(attn_outputs[1], attn_outputs[2]), "attention outputs of layers 1 and 2 differ"
assert not torch.equal(attn_outputs[0], attn_outputs[1]), "attention outputs of layers 0 and 1 are identical"

In [None]:
cfg._attn_implementation = "sdpa"
cfg.output_attentions = False
model = AutoModelForCausalLM.from_config(cfg)

In [None]:
model.to(device="cuda", dtype=torch.bfloat16);
model.load_state_dict(model_state_dict);

In [None]:
x

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]],
       device='cuda:0')

In [None]:
# Evaluation-only forward pass: no_grad avoids keeping an autograd graph alive for the
# returned logits and any debug tensors stored on the modules.
with torch.no_grad():
    out_sdpa = model(x)

In [None]:
out_eager.logits, out_sdpa.logits

(tensor([[[ 6.0156e-01, -3.9551e-02,  3.9648e-01,  ...,  2.4316e-01,
            2.6172e-01,  3.6328e-01],
          [ 8.2031e-01, -9.2773e-02,  7.2754e-02,  ...,  9.3359e-01,
           -1.4551e-01,  4.7852e-01],
          [ 5.2002e-02, -1.1475e-01,  2.6953e-01,  ...,  7.3047e-01,
           -1.2061e-01,  2.8516e-01],
          ...,
          [-6.1512e-05, -2.0898e-01,  8.9062e-01,  ...,  2.9883e-01,
           -2.5195e-01, -6.5234e-01],
          [ 3.9307e-02, -3.9453e-01,  6.5625e-01,  ...,  6.1719e-01,
           -5.0000e-01, -9.2188e-01],
          [-1.2305e-01,  1.9434e-01,  4.4922e-01,  ...,  6.0938e-01,
           -1.9434e-01, -1.0078e+00]]], device='cuda:0', dtype=torch.bfloat16,
        grad_fn=<UnsafeViewBackward0>),
 tensor([[[ 0.6016, -0.0396,  0.3965,  ...,  0.2432,  0.2617,  0.3633],
          [ 0.8242, -0.0962,  0.0830,  ...,  0.9297, -0.1270,  0.4883],
          [ 0.0569, -0.1094,  0.2754,  ...,  0.7305, -0.1099,  0.2988],
          ...,
          [-0.0019, -0.2090,  0

In [None]:
# Repeat the CLA sharing check on the SDPA model: the debug outputs of layers 0/3 and
# 1/2 must be identical, and the two groups must differ from each other.
attn_outputs = list(layer.self_attn.debug_cla_attn_output for layer in model.model.layers)
assert len(attn_outputs) == 4
assert all(torch.equal(attn_outputs[i], attn_outputs[j]) for i, j in ((0, 3), (1, 2)))
assert not torch.equal(attn_outputs[0], attn_outputs[1])

In [None]:
# Eager and SDPA bf16 kernels round slightly differently (see the logits printed above),
# so require that nearly all logits agree within a loose tolerance instead of exact equality.
LOGITS_TOL = 0.1
MIN_AGREEMENT = 0.99
agreement = torch.isclose(out_eager.logits, out_sdpa.logits, rtol=LOGITS_TOL, atol=LOGITS_TOL).float().mean().item()
assert agreement > MIN_AGREEMENT, f"only {agreement:.2%} of eager/SDPA logits agree (need > {MIN_AGREEMENT:.0%})"

In [None]:
del model, attn_outputs
torch.cuda.empty_cache()

In [None]:
# Regression check: with an identity CLA map (every layer uses its own KV) and fp8
# scaling disabled, the model must reproduce the baseline (cla_kv_cache_map=None)
# bit-for-bit when both share the same weights.
cfg._attn_implementation = "sdpa"
cfg.use_fp8_kv_scale = False
cfg.cla_kv_cache_map = {0:0, 1:1, 2:2, 3:3}
cfg.palu_kv_compression_enabled = False
cfg.use_cache = False
cfg.debug_kv_sharing = False
cfg.output_attentions = False

model = AutoModelForCausalLM.from_config(cfg)
model.to(device="cuda", dtype=torch.bfloat16)
model_state_dict = model.state_dict()
with torch.no_grad():  # inference only; don't retain an autograd graph
    test_output = model(x)

# Baseline: same weights, CLA disabled.
cfg.cla_kv_cache_map = None
model = AutoModelForCausalLM.from_config(cfg)
model.to(device="cuda", dtype=torch.bfloat16)
model.load_state_dict(model_state_dict)
with torch.no_grad():
    base_output = model(x)

assert torch.equal(test_output.logits, base_output.logits), "identity CLA map does not reproduce the baseline logits"

In [None]:
def fp8_quant_dequant(x, scale):
    # dtype -> fp8 / fp8 -> dtype
    assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization"
    x_fp8 = (x.to(torch.float32)/scale).to(torch.float8_e4m3fn)
    assert not x_fp8.isnan().any(), "key or value states contain NaN after fp8 quantization"
    x_dequant = (x_fp8.to(torch.float32)*scale).to(x.dtype)
    assert not x_dequant.isnan().any(), "key or value states contain NaN after fp8 dequantization"
    return x_dequant

In [None]:
%%ai

do you have any suggestion for fixing potential NaN values inside `fp8_quant_dequant`. Consider min and max values of `torch.float8_e4m3fn` dtype so that values don't underflow or overflow.

##### Assistant

To address potential NaN values in `fp8_quant_dequant`, we should consider the range limitations of `torch.float8_e4m3fn`. Here's an improved version:

In [None]:
def fp8_quant_dequant(x, scale):
    assert not x.isnan().any(), "key or value states contain NaN before fp8 quantization"
    
    # Get min/max values for float8_e4m3fn
    f8_min, f8_max = torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max
    
    # Clamp x/scale to float8 range before conversion
    x_fp32 = x.to(torch.float32) / scale
    x_fp32_clamped = torch.clamp(x_fp32, f8_min, f8_max)
    
    x_fp8 = x_fp32_clamped.to(torch.float8_e4m3fn)
    assert not x_fp8.isnan().any(), "key or value states contain NaN after fp8 quantization"
    
    x_dequant = (x_fp8.to(torch.float32) * scale).to(x.dtype)
    assert not x_dequant.isnan().any(), "key or value states contain NaN after fp8 dequantization"
    
    return x_dequant

This implementation:

1. Converts `x` to float32 and divides by `scale`.
2. Clamps values to the float8_e4m3fn range before conversion.
3. Casts the clamped values to float8_e4m3fn, upcasts them back to float32, multiplies by `scale`, and finally casts the result to the original dtype of `x`.

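If NaNs still occur, they no longer come from overflow in the fp8 cast. Check `x` itself for NaNs, and make sure `scale` is finite and non-zero. A well-chosen `scale` also reduces how many values are saturated by the clamp, which keeps them inside the float8 range without clipping.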

##### User

In [None]:
f8_min, f8_max = torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max

In [None]:
f8_min, f8_max

(-448.0, 448.0)

In [None]:
# Sanity-check the model-level CLA KV cache: it should hold two entries (presumably
# one per distinct KV-producing layer in cla_kv_cache_map; confirm), each a
# (key, value) pair whose tensors have matching shapes. Check every pair, not just the first.
assert len(model.model.cla_key_value) == 2
for cached_k, cached_v in model.model.cla_key_value:
    assert cached_k.shape == cached_v.shape, f"key/value shape mismatch: {cached_k.shape} vs {cached_v.shape}"
# Bind the first cached pair explicitly; later cells inspect `k`.
k, v = model.model.cla_key_value[0]

In [None]:
model.model.cla_key_value[0][0].isnan().any()

tensor(True, device='cuda:0')

In [None]:
def fp8_quant_dequant(x, scale):
    # dtype -> fp8 / fp8 -> dtype
    x_fp8 = (x.to(torch.float32)/scale).to(torch.float8_e4m3fn)
    x_dequant = (x_fp8.to(torch.float32)*scale).to(x.dtype)
    return x_dequant

In [None]:
k

tensor([[[[-2.0800, -0.1200,  0.0400,  ..., -3.2000,  0.1400,  0.2000],
          [-0.8800, -1.6000, -1.6000,  ..., -0.8000,  0.8000, -0.0700],
          [-1.2800, -1.2000,  0.0700,  ..., -0.1500,  1.9200, -2.4000],
          [-0.2000,  0.8000,  0.2400,  ...,  0.4800, -2.8800, -1.1200],
          [ 0.4800, -0.1800, -1.1200,  ...,  0.8000,  0.6000,  1.1200]],

         [[-1.7600, -1.7600,  1.7600,  ..., -0.5200,  0.1800, -1.4400],
          [-1.2000, -0.6400, -0.1400,  ..., -0.2400, -1.7600,  0.4800],
          [-1.9200,  0.5200,  0.1200,  ..., -0.9600, -1.9200,  3.8400],
          [-1.2800,  0.3000,  0.0375,  ..., -0.1500,  0.5600, -0.4400],
          [-1.6000,  2.5600, -0.6400,  ...,  3.2000, -0.9600, -1.6000]],

         [[-2.5600, -0.7200, -1.2000,  ...,  1.6000,  1.7600,  1.0400],
          [-1.9200, -1.1200,  2.5600,  ...,  0.4400,  3.2000,  0.8000],
          [ 0.4800, -0.4000, -0.7200,  ...,  1.2000, -0.5600,  2.0800],
          [ 0.8000, -0.0300,  0.0225,  ...,  0.2600,  1.2000

In [None]:
fp8_quant_dequant(k, model.model.layers[0].self_attn.k_scale)

tensor([[[[-2.0800, -0.1200,  0.0400,  ..., -3.2000,  0.1400,  0.2000],
          [-0.8800, -1.6000, -1.6000,  ..., -0.8000,  0.8000, -0.0700],
          [-1.2800, -1.2000,  0.0700,  ..., -0.1500,  1.9200, -2.4000],
          [-0.2000,  0.8000,  0.2400,  ...,  0.4800, -2.8800, -1.1200],
          [ 0.4800, -0.1800, -1.1200,  ...,  0.8000,  0.6000,  1.1200]],

         [[-1.7600, -1.7600,  1.7600,  ..., -0.5200,  0.1800, -1.4400],
          [-1.2000, -0.6400, -0.1400,  ..., -0.2400, -1.7600,  0.4800],
          [-1.9200,  0.5200,  0.1200,  ..., -0.9600, -1.9200,  3.8400],
          [-1.2800,  0.3000,  0.0375,  ..., -0.1500,  0.5600, -0.4400],
          [-1.6000,  2.5600, -0.6400,  ...,  3.2000, -0.9600, -1.6000]],

         [[-2.5600, -0.7200, -1.2000,  ...,  1.6000,  1.7600,  1.0400],
          [-1.9200, -1.1200,  2.5600,  ...,  0.4400,  3.2000,  0.8000],
          [ 0.4800, -0.4000, -0.7200,  ...,  1.2000, -0.5600,  2.0800],
          [ 0.8000, -0.0300,  0.0225,  ...,  0.2600,  1.2000

In [None]:
k.dtype

torch.float32

### Bias loading and quantization

In [None]:
# HQQ quantization configs: the two differ only in bit width and group size.
quant_config_4bit = BaseQuantizeConfig(
    nbits=4,
    group_size=128,
    quant_zero=False,
    quant_scale=False,
    offload_meta=False,
    view_as_float=True,
    axis=1,
)
quant_config_2bit = BaseQuantizeConfig(
    nbits=2,
    group_size=32,
    quant_zero=False,
    quant_scale=False,
    offload_meta=False,
    view_as_float=True,
    axis=1,
)

# Layer-name selection: every attention and MLP projection goes to 4-bit.
attn_layers = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_layers = ["gate_proj", "up_proj", "down_proj"]
layers_4bit = [*attn_layers, *mlp_layers]
layers_2bit = []  # empty, so presumably quant_config_2bit is never applied
skip_modules = ["lm_head"]
block_influence_patterns = []

rank = 0  # passed to replace_linear as `device`
compute_dtype = torch.bfloat16

# Swap the decoder's linear layers for HQQLinear wrappers (initialize=False: no
# quantization performed at replacement time, per the flag name; confirm in quant_utils).
model.model = replace_linear(
    model=model.model,
    linear_replacement=HQQLinear,
    quant_config_4bit=quant_config_4bit,
    quant_config_2bit=quant_config_2bit,
    layers_4bit=layers_4bit,
    layers_2bit=layers_2bit,
    skip_modules=skip_modules,
    block_influence_patterns=block_influence_patterns,
    prefix='model',
    device=rank,
    compute_dtype=compute_dtype,
    del_orig=True,
    initialize=False,
)

model.layers
[]
model.layers.0
[]
model.layers.0.self_attn
[]
Replacing model.layers.0.self_attn.q_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
Replacing model.layers.0.self_attn.k_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
Replacing model.layers.0.self_attn.v_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
Replacing model.layers.0.self_attn.o_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
model.layers.0.mlp
[]
Replacing model.layers.0.mlp.gate_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
Replacing model.layers.0.mlp.up_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
Replacing model.layers.0.mlp.down_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
model.layers.1
[]
model.layers.1.self_attn
[]
Replacing model.layers.1.self_attn.q_proj with <class 'hqq.core.quantize.HQQLinear'> with 4-bit groupsize 128
R

In [None]:
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 5120)
    (layers): ModuleList(
      (0-3): 4 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): HQQLinear(
            (linear_layer): Linear(in_features=5120, out_features=5120, bias=True)
          )
          (k_proj): HQQLinear(
            (linear_layer): Linear(in_features=5120, out_features=1024, bias=True)
          )
          (v_proj): HQQLinear(
            (linear_layer): Linear(in_features=5120, out_features=1024, bias=True)
          )
          (o_proj): HQQLinear(
            (linear_layer): Linear(in_features=5120, out_features=5120, bias=False)
          )
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): HQQLinear(
            (linear_layer): Linear(in_features=5120, out_features=27648, bias=False)
          )
          (up_proj): HQQLinear(
            (linear_layer): Linear(in_features=5120, out_featu

In [None]:
idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(model_name, idx)

In [None]:
# Inspect a single checkpoint shard (the second one) rather than the full model.
# `load_file` is imported explicitly from safetensors.torch: a plain `import safetensors`
# does not load the torch submodule, so `safetensors.torch` only worked via a
# transitive import from another library.
weights = load_file(files[1])

In [None]:
# Pair every non-bias tensor in this shard with its ".bias" counterpart (None if the
# shard has none) so weight and bias can be handled together; bias tensors are not
# keyed separately.
# NOTE(review): a bias stored in a different shard than its weight is not found here.
weights_copy = {}
for name, param in weights.items():
    if name.endswith(".bias"):
        continue
    bias_param = None
    if name.endswith(".weight"):
        # Replace only the trailing suffix; str.replace would also rewrite an inner
        # ".weight" (e.g. "x.weight_proj.weight" -> "x.bias_proj.bias").
        bias_param = weights.get(name[: -len(".weight")] + ".bias")
    weights_copy[name] = (param, bias_param)

In [None]:
# One line per entry: tensor name, weight shape, and bias shape (None when there is no bias).
for tensor_name, (weight, bias) in weights_copy.items():
    print(tensor_name, weight.shape, None if bias is None else bias.shape)

model.layers.2.mlp.down_proj.weight torch.Size([5120, 27648]) None
model.layers.2.mlp.gate_proj.weight torch.Size([27648, 5120]) None
model.layers.3.input_layernorm.weight torch.Size([5120]) None
model.layers.3.mlp.down_proj.weight torch.Size([5120, 27648]) None
model.layers.3.mlp.gate_proj.weight torch.Size([27648, 5120]) None
model.layers.3.mlp.up_proj.weight torch.Size([27648, 5120]) None
model.layers.3.post_attention_layernorm.weight torch.Size([5120]) None
model.layers.3.self_attn.k_proj.weight torch.Size([1024, 5120]) torch.Size([1024])
model.layers.3.self_attn.o_proj.weight torch.Size([5120, 5120]) None
model.layers.3.self_attn.q_proj.weight torch.Size([5120, 5120]) torch.Size([5120])
model.layers.3.self_attn.v_proj.weight torch.Size([1024, 5120]) torch.Size([1024])
model.layers.4.input_layernorm.weight torch.Size([5120]) None
model.layers.4.mlp.down_proj.weight torch.Size([5120, 27648]) None
model.layers.4.mlp.gate_proj.weight torch.Size([27648, 5120]) None
model.layers.4.mlp.u

In [None]:
torch.nn.Linear(128,64,bias=True).get_parameter("bias")

Parameter containing:
tensor([ 0.0277,  0.0694,  0.0460,  0.0540,  0.0410,  0.0492,  0.0447, -0.0036,
        -0.0181,  0.0537,  0.0414,  0.0143,  0.0643, -0.0863, -0.0174,  0.0551,
        -0.0195, -0.0147,  0.0242, -0.0850,  0.0644, -0.0038, -0.0637, -0.0624,
         0.0111,  0.0059, -0.0147, -0.0552,  0.0768, -0.0064, -0.0505,  0.0064,
        -0.0104, -0.0583, -0.0130,  0.0077,  0.0257,  0.0842,  0.0632,  0.0396,
        -0.0588, -0.0166,  0.0791, -0.0845, -0.0719, -0.0801, -0.0023, -0.0765,
         0.0230, -0.0268, -0.0444,  0.0556,  0.0225, -0.0201, -0.0292,  0.0430,
        -0.0149,  0.0852, -0.0632, -0.0257,  0.0270,  0.0235,  0.0110, -0.0479],
       requires_grad=True)

In [None]:
# A Linear built without a bias has no `bias` Parameter, so get_parameter("bias")
# raises AttributeError. Catch and show it so the notebook still runs top to bottom.
try:
    torch.nn.Linear(128, 64, bias=False).get_parameter("bias")
    bias_lookup_error = None
except AttributeError as err:
    bias_lookup_error = err
print(f"{type(bias_lookup_error).__name__}: {bias_lookup_error}")

AttributeError: `bias` is not an nn.Parameter

In [None]:
# Wrap a plain tensor as an nn.Parameter directly (requires_grad defaults to True);
# there is no need to build a throwaway Linear, which also consumes RNG, just to
# reach the Parameter type.
demo_param = torch.nn.Parameter(torch.tensor([1., 2., 4., 8.]))
demo_param

Parameter containing:
tensor([1., 2., 4., 8.], requires_grad=True)