In [1]:
# Path to HF model

from pathlib import Path
MODEL_ID = "Zyphra_Zamba2-7B-instruct"
MODEL_PATH=Path(f"~/models/{MODEL_ID}").expanduser()

In [2]:
# Mamba layer was modified to get dtype based on conv1d rather than inputs. 
import modeling_zamba2
import configuration_zamba2


  return torch.library.impl_abstract(f"{name}")(func)
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


In [3]:
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, HQQBackend
from hqq.core.optimize import optimize_weights_proximal
from hqq.core.quantize import Quantizer
# This supposed to produce better results
Quantizer.optimize_weights = optimize_weights_proximal


In [4]:
USE_AOINT4 = True

# we will define three level of quants to use later: 
BASIC_QUANT={'nbits': 4, 'group_size': 64}
#BETTER_QUANT={'nbits': 8, 'group_size': 64} -- no aoint4
BETTER_QUANT={'nbits': 4, 'group_size': 32}
#BETTER_QUANT=BASIC_QUANT
NO_QUANT = None
AXIS = (1) if USE_AOINT4 else (0)

print(f"AO4: {USE_AOINT4}")

AO4: True


In [5]:
import torch
import transformers
import time
import torch.nn as nn

from copy import deepcopy
from typing import Optional, Union, Literal

In [6]:
model = modeling_zamba2.Zamba2ForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="cuda", 
    torch_dtype=torch.bfloat16
)
model

Zamba2ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Zamba2ForCausalLM(
  (model): Zamba2Model(
    (embed_tokens): Embedding(32000, 3584, padding_idx=0)
    (blocks): ModuleList(
      (0-1): 2 x Zamba2AttentionDecoderLayer(
        (self_attn): Zamba2SdpaAttention(
          (q_proj): Linear(in_features=7168, out_features=7168, bias=False)
          (k_proj): Linear(in_features=7168, out_features=7168, bias=False)
          (v_proj): Linear(in_features=7168, out_features=7168, bias=False)
          (o_proj): Linear(in_features=7168, out_features=3584, bias=False)
          (rotary_emb): Zamba2RotaryEmbedding()
        )
        (feed_forward): Zamba2MLP(
          (linear_fc1): Linear(in_features=3584, out_features=28672, bias=False)
          (linear_fc2): Linear(in_features=14336, out_features=3584, bias=False)
          (linear_fc1_lora_A_list): ParameterList(
              (0): Object of type: Linear
              (1): Object of type: Linear
              (2): Object of type: Linear
              (3): Object of type: Linear
       

In [7]:
def model_numel(m: nn.Module):
    return sum(p.numel() for p in m.parameters())   

def model_sz(m: nn.Module):
    return sum(p.numel() * p.dtype.itemsize for p in m.parameters())   

In [8]:
def model_parm_count(count_fn=model_sz):
    # NOTE: doesn't work with hqq
    print(f"TOTAL PARMS: {model_sz(model):_}")
    print(f"TRANS PARMS: {count_fn(model.model.blocks):_}")
    print(f"LINEAR PARMS: {count_fn(model.model.linear_layers):_}")
    print(f"TOTAL MAMBA: {count_fn(model.model.mamba_layers):_}")
    print(f"=== SINGLE MAMBA === x {len(model.model.mamba_layers)}")
    m = model.model.mamba_layers[0].mamba
    print(f"TOTAL: {count_fn(m):_}")
    print(f"IN: {count_fn(m.in_proj):_}")
    print(f"CONV: {count_fn(m.conv1d):_}")
    print(f"OUT: {count_fn(m.out_proj):_}")
model_parm_count()

TOTAL PARMS: 14_820_847_264
TRANS PARMS: 1_550_624_768
LINEAR PARMS: 333_971_456
TOTAL MAMBA: 12_706_867_872
=== SINGLE MAMBA === x 81
TOTAL: 156_867_744
IN: 105_398_272
CONV: 74_240
OUT: 51_380_224


In [9]:
# actual function to convert nn.layer to hqq
def do_quant(nn_linear, quant_config: Optional[BaseQuantizeConfig]):
    if not quant_config:
        return nn_linear    
    assert isinstance(nn_linear, nn.Linear), f"expected nn.Linear, got {type(nn_linear)}"
    hqq_layer = HQQLinear(nn_linear, 
                      quant_config=BaseQuantizeConfig(**quant_config, axis=AXIS),
                      compute_dtype=torch.bfloat16,
                      device='cuda', 
                      initialize=True, 
                      del_orig=True) 
    return hqq_layer

#  Actual quantization

## Block self attention

Main points of interest: q_proj, k_proj, v_proj, o_proj.

It also has all loras, but I didn't want to touch them

In [10]:
qkvo_quant = NO_QUANT
for block in model.model.blocks:
    block = block.self_attn    
    block.q_proj = do_quant(block.q_proj, qkvo_quant)
    block.k_proj = do_quant(block.k_proj, qkvo_quant)
    block.v_proj = do_quant(block.v_proj, qkvo_quant)
    block.o_proj = do_quant(block.o_proj, qkvo_quant)

## MLP

It has shared linear_fc1(up+gate), linear_fc2(down), lora.
First two are the main points of the interest

In [11]:
# Without moving some memory to CPU we'll hit OoM
model.model.mamba_layers.cpu()

ModuleList(
  (0-80): 81 x Zamba2MambaDecoderLayer(
    (mamba): Mamba2Layer(
      (in_proj): ModuleList(
        (0): Linear(in_features=3584, out_features=14704, bias=False)
      )
      (conv1d): Conv1d(7424, 7424, kernel_size=(4,), stride=(1,), padding=(3,), groups=7424)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=7168, out_features=3584, bias=False)
    )
    (input_layernorm): Zamba2RMSNorm()
  )
)

In [12]:
mlp_quant = NO_QUANT
for block in model.model.blocks:
    block = block.feed_forward
    block.linear_fc1 = do_quant(block.linear_fc1, mlp_quant)
    block.linear_fc2 = do_quant(block.linear_fc2, mlp_quant)

In [13]:
model.model.mamba_layers.cuda()

ModuleList(
  (0-80): 81 x Zamba2MambaDecoderLayer(
    (mamba): Mamba2Layer(
      (in_proj): ModuleList(
        (0): Linear(in_features=3584, out_features=14704, bias=False)
      )
      (conv1d): Conv1d(7424, 7424, kernel_size=(4,), stride=(1,), padding=(3,), groups=7424)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=7168, out_features=3584, bias=False)
    )
    (input_layernorm): Zamba2RMSNorm()
  )
)

## Additional layers

Zamba has additional linear layers for transformers

In [14]:
lin_quant = NO_QUANT

block = model.model.linear_layers
for i in range(len(block)):
    block[i] = do_quant(block[i], lin_quant)

## Mamba.

Mamba has two linear blocks: in_proj, out_proj.
It also has conv1d, but quanting conv it would requrie switching to even less efficient path course and there are next to no parms in conv, while in_proj consumes the majority of space

In [15]:
block = model.model.mamba_layers

def nth(n): return block[n].mamba

GENERAL_QUANT = BETTER_QUANT
i = 0
for i in range(i, len(block)):
    nth(i).in_proj[0] = do_quant(nth(i).in_proj[0], GENERAL_QUANT)
    nth(i).out_proj = do_quant(nth(i).out_proj, GENERAL_QUANT)
    i += 1
    
assert i == len(block)

In [16]:
if USE_AOINT4:
    from hqq.utils.patching import prepare_for_inference
    prepare_for_inference(model, backend="torchao_int4") 
else:
    HQQLinear.set_backend(HQQBackend.ATEN_FORWARD)



# Testing

In [17]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

In [18]:
# modeling_zamba2 cache implementation doesn't work with non-forked vesrion
# as it doesn't inherit from Cache, which breaks the generation.
class FixedHybraCache(modeling_zamba2.HybridMambaAttentionDynamicCache, transformers.Cache):
    ...


In [19]:
#prompt = "What factors contributed to the fall of the Roman Empire?"
#prompt = "Give me a list of good god-tier reasons why kittehs are cuter than doggos. Make it verbose: several sentences per item."
prompt = "Write a long 1000 words story about dreams"


In [20]:
def run_model(prompt, max_new_tokens=150, is_instruct=True):
    if is_instruct:
        sample = [{'role': 'user', 'content': prompt}]
        prompt = tokenizer.apply_chat_template(sample, tokenize=False)
    print(prompt)    
    input_ids = tokenizer(prompt, return_tensors='pt', add_special_tokens=False).to("cuda")
    start_time = time.time()    
    cache=FixedHybraCache(model.config, 1)
    cache_exist = cache is not None
    outputs = model.generate(
        **input_ids, 
        max_new_tokens=1000, #1000
        return_dict_in_generate=False, 
        output_scores=False,     
        num_beams=1, 
        use_cache=cache is not None,
        past_key_values=cache,
        do_sample=False)
    delta_time = time.time() - start_time
    tps = outputs.numel() / delta_time
    print(f"*** TIME: {delta_time}, TPS:{tps} {cache_exist=}, |T|:{outputs.numel()}")
    print((tokenizer.decode(outputs[0])))

run_model(prompt)

# ATEN: *** TIME: 305.0571117401123, TPS:3.2485720275364827 cache_exist=True, |T|:991
# AO4:  *** TIME: 60.70746970176697, TPS:16.768941367529436 cache_exist=True, |T|:1018
# RAW:  *** TIME: 57.77293086051941, TPS:15.560920773959806 cache_exist=True, |T|:899

<|im_start|>user
Write a long 1000 words story about dreams<|im_end|>

*** TIME: 65.19050335884094, TPS:15.615771432173515 cache_exist=True, |T|:1018
<|im_start|> user
Write a long 1000 words story about dreams<|im_end|> 
<|im_start|> assistant
Once upon a time, in a small village nestled between rolling hills and lush forests, there lived a young girl named Lila. Lila was a dreamer, always lost in her own world of imagination and wonder. She had a special gift - the ability to enter the realm of dreams.

Every night, as she drifted off to sleep, Lila would find herself transported to a magical land filled with vibrant colors, enchanting creatures, and breathtaking landscapes. In this dream world, she could fly through the skies, explore hidden caves, and converse with wise old sages. It was a place where anything was possible, and Lila reveled in the freedom and joy it brought her.

One evening, as Lila lay in bed, she felt a strange sensation wash over her. The dream world seemed dif

In [21]:
model

Zamba2ForCausalLM(
  (model): Zamba2Model(
    (embed_tokens): Embedding(32000, 3584, padding_idx=0)
    (blocks): ModuleList(
      (0-1): 2 x Zamba2AttentionDecoderLayer(
        (self_attn): Zamba2SdpaAttention(
          (q_proj): Linear(in_features=7168, out_features=7168, bias=False)
          (k_proj): Linear(in_features=7168, out_features=7168, bias=False)
          (v_proj): Linear(in_features=7168, out_features=7168, bias=False)
          (o_proj): Linear(in_features=7168, out_features=3584, bias=False)
          (rotary_emb): Zamba2RotaryEmbedding()
        )
        (feed_forward): Zamba2MLP(
          (linear_fc1): Linear(in_features=3584, out_features=28672, bias=False)
          (linear_fc2): Linear(in_features=14336, out_features=3584, bias=False)
          (linear_fc1_lora_A_list): ParameterList(
              (0): Object of type: Linear
              (1): Object of type: Linear
              (2): Object of type: Linear
              (3): Object of type: Linear
       