In [1]:
from pathlib import Path
import torch
import torch.nn as nn
# Quantizing
from hqq.core.quantize import *
from typing import Optional
from tqdm.auto import tqdm
import re
# Reconstructing 
from transformers import AutoConfig, AutoModelForCausalLM

[36mhqq_aten package available. Set backend to HQQBackend.ATEN for faster inference and HQQBackend.ATEN_BACKPROP for faster training![0m


In [2]:
out_dir = Path("hqqfied")
out_dir.mkdir(exist_ok=True)
bin_files = list(Path(".").glob("pytorch*.bin"))

## Quantization

In [38]:
zero_scale_group_size                              = 128
attn_parms    = BaseQuantizeConfig(nbits=4, group_size=64, offload_meta=True) 
experts_parms = BaseQuantizeConfig(nbits=2, group_size=8, offload_meta=True)

attn_parms['scale_quant_params']['group_size']    = zero_scale_group_size
attn_parms['zero_quant_params']['group_size']     = zero_scale_group_size
experts_parms['scale_quant_params']['group_size'] = zero_scale_group_size
experts_parms['zero_quant_params']['group_size']  = zero_scale_group_size


[33mquant_zero and quant_scale must be the same when offload_meta is set to True. Setting quant_scale=quant_zero.[0m
[33mquant_zero and quant_scale must be the same when offload_meta is set to True. Setting quant_scale=quant_zero.[0m


In [53]:

class MixtralQuantizeConfig:
    RE_REQUIRED_PREFIX = re.compile(r"model\.layers.(\d+)\.")
    RE_MOE_WEIGHT = re.compile(r"block_sparse_moe\.experts\.\d+\.w\d+\.weight")
    RE_ATTN_WEIGHT = re.compile(r"self_attn\.[qkvo]_proj\.weight")
    RE_NO_QUANT = re.compile(r"""
        block_sparse_moe\.gate.weight
    |   input_layernorm.weight
    |   post_attention_layernorm.weight
    """, re.VERBOSE)

    def __init__(self, attn_parms: dict, experts_parms: dict) -> None:
        self.attn_parms = attn_parms
        self.experts_parms = experts_parms

    def quantconfig_from_tensor_name(self, full_name: str) -> Optional[dict]:
        match = self.RE_REQUIRED_PREFIX.match(full_name)
        if not match: return None
        name = full_name[match.span()[1]:]
        if self.RE_NO_QUANT.match(name): return None
        if self.RE_MOE_WEIGHT.match(name): return self.experts_parms
        if self.RE_ATTN_WEIGHT.match(name): return self.attn_parms
        raise ValueError(f"Don't know what to do with {full_name} (key={name})")

    @property
    def compute_dtype(self):
        return torch.bfloat16
    
    def tensor_name_to_layer_idx(self, tensor_name: str):
        match = self.RE_REQUIRED_PREFIX.match(tensor_name)
        if match is None:
            return None
        [start, end] = match.span(1)
        idx = int(tensor_name[start:end])
        return idx

        

In [55]:
quant_cfg = MixtralQuantizeConfig(attn_parms=attn_parms, experts_parms=experts_parms)

In [None]:
import gc
def quantize_file(p: Path, config_getter, compute_dtype, is_secondary_tqdm=True):
    raw_data = torch.load(p)
    # We need to store mappings from original to quantized
    quantized_state  = {}

    for (k, v) in (bar := tqdm(raw_data.items(), leave=not is_secondary_tqdm)):
        raw_data[k] = None # Loaded data no longer needed
        if not (cfg := config_getter(k)):
            quantized_state[k] = v
            continue
        bar.set_description(k)
        
        # at this point K contains name of nn.Linear.weight
        # since Mixtral doesn't use biases we can reconstruct Linear completely
        weight: torch.Tensor = raw_data[k]
        out_features, in_features = weight.shape # nn.Linears weights are transposed compared to X@W, so out is shape[0]
        linear = nn.Linear(in_features, out_features, dtype=weight.dtype, bias=False)
        linear.weight.data = weight
        
        # Now we can quantize it.
        # No need for del_orig as even if dels it, we have `linear` ourself, so HQQLinear can't delete much
        quantized = HQQLinear(linear, cfg, compute_dtype=compute_dtype, del_orig=False) 
        quantized_state[k] = quantized.state_dict()
        del linear
        gc.collect() # Not sure it's needed, but probably wouldn't hurt either
        

    torch.save(quantized_state, out_dir/p.name)
    del quantized_state
    # Aggressive cleanup. Not sure how redundant it is, but it reduced memory usage to ~20G from ~27G when I tried 3 layers only
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()


In [None]:
for p in tqdm(bin_files):
    quantize_file(p, quant_cfg.quantconfig_from_tensor_name, torch.bfloat16)

## Reconstrction

In [36]:
# NB. I restarted notebook from this cell to make sure memory is not used for anything else or selecteviely run previous
# cells to have quant_cfg

# First we need to build a dummy model. 
dummy_config = AutoConfig.from_pretrained(".")
layers = dummy_config.num_hidden_layers
# We force it to have 1 layer to not waste RAM
dummy_config.num_hidden_layers = 1
model = AutoModelForCausalLM.from_config(dummy_config, torch_dtype=quant_cfg.compute_dtype)

In [4]:
quantized_files = list(out_dir.glob("*.bin"))
quantized_files.sort() # We need them to be sorted to iterate layer by layer

In [88]:
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

In [92]:
def ensure_layer_exists(model: nn.Module, tensor_name: str):
    tensor_name_parts = tensor_name.split('.')
    # Find a layer or append
    layer_name = '.'.join(tensor_name_parts[:-1])
    layer_idx = quant_cfg.tensor_name_to_layer_idx(tensor_name)
    if layer_idx is None:
        return model.get_submodule(layer_name)
    # We add one layer at the time
    assert layer_idx <= len(model.model.layers), "layer index must be within existing layers or 1 above"
    if layer_idx == len(model.model.layers):
        new_layer = MixtralDecoderLayer(model.config, layer_idx).to(quant_cfg.compute_dtype)
        model.model.layers.append(new_layer)
        model.config.num_hidden_layers += 1
    return model.get_submodule(layer_name)


In [123]:
HQQLinear.set_backend(HQQBackend.ATEN)

In [99]:
def load_quantized(p: Path, is_secondary_tqdm=True):
    raw_data = torch.load(p)
    for k, v in (bar := tqdm(raw_data.items(), leave=not is_secondary_tqdm)):
        raw_data[k] = None
        tensor_name_parts = k.split('.')
        bar.set_description(k)
        layer = ensure_layer_exists(model, k)
        if isinstance(v, torch.Tensor):
            new_state_dict = {tensor_name_parts[-1] : v}
            layer.load_state_dict(new_state_dict)
            continue
        assert isinstance(v, dict), "Expected state_dict of quantized layer"
        linear_holder_name = '.'.join(tensor_name_parts[:-2]) # :-1 tensor holder(linear) :-2 linear holder
        linear_holder = model.get_submodule(linear_holder_name)
        linear_name = tensor_name_parts[-2] # [-1]=.weight [-2] = .q_proj
        existing_linear = linear_holder.get_submodule(linear_name)
        if isinstance(existing_linear, (nn.Linear, HQQLinear)):
            # Need to replace
            hqq_linear = HQQLinear(None, quant_config=quant_cfg.quantconfig_from_tensor_name(k))
            hqq_linear.load_state_dict(v)
            setattr(linear_holder, linear_name, hqq_linear)
        else:
            raise ValueError(f"{k}: Expecting nn.Linear/HQQLinear, not {type(existing_linear)}")
    del raw_data
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()

In [101]:
for p in tqdm(quantized_files):
    load_quantized(p)

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

In [103]:
model.cuda() # 13GB

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32002, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralSdpaAttention(
          (q_proj): HQQLinear()
          (k_proj): HQQLinear()
          (v_proj): HQQLinear()
          (o_proj): HQQLinear()
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): HQQLinear()
              (w2): HQQLinear()
              (w3): HQQLinear()
              (act_fn): SiLU()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm()
        (post_attention_layernorm): MixtralRMSNorm()
      )
    )
    (norm): MixtralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32002, bias=False)
)

## Inference

In [104]:
from transformers import AutoTokenizer

In [105]:
tokenizer =  AutoTokenizer.from_pretrained(".")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [120]:
messages = [
    {"role": "user", "content": "Touhou. Please describe relationship between Remilia Scarlet and Sakuya Izayoi"},
]

In [129]:
encoded = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
y = model.generate(encoded, max_new_tokens=128)
# With PYTORCH backend: 10m42s
# With ATEN backend: 4m08s

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32000 for open-end generation.


In [130]:
print(tokenizer.decode(y.ravel()))


<|im_start|> user
Touhou. Please describe relationship between Remilia Scarlet and Sakuya Izayoi<|im_end|> 
<|im_start|> assistant
 Remilia Scarlet is the head of the Scarlet Devil Mansion and the leader of the Scarlet Devil Team. She is a vampire and the last of the Scarlet Devil lineage. On the other hand, Sakuya Izayoi is a servant of Remilia Scarlet and the head maid of the Scarlet Devil Mansion. She is a skilled martial artist and has the ability to manipulate time and space. The relationship between Remilia Scarlet and Sakuya Izayoi is that of a master and servant, with Remilia being the master and Sakuya being the servant
