In [1]:
import os
import time 
import torch
import torch.nn as nn
import GPUtil
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
)
from transformers import LlamaTokenizer, AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaAttention
import tqdm
from functools import partial

  from .autonotebook import tqdm as notebook_tqdm


## Resource Detection

In [2]:
gpus = GPUtil.getGPUs()
free_memory = []

for gpu in gpus:
    free_memory.append(gpu.memoryFree)

memory_sort = sorted(range(len(free_memory)), key=lambda i: free_memory[i])

gpu_id = memory_sort[-1]
gpu_memory = free_memory[memory_sort[-1]]

print(f'gpu_id:{gpu_id}; gpu_memory:{gpu_memory}')

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

gpu_id:5; gpu_memory:49340.0


In [3]:
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

## Model Selection

In [4]:
model_name = "/data/zbr/LLMs/Llama-2-7b-hf"
# model_name = "/data/llms/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 42.36it/s]


In [5]:
### llama_2 layer importance: [27, 26, 25, 28, 24, 29, 23, 21, 22, 30, 19, 20, 18, 17, 16, 11, 12, 13, 14, 15, 10, 9, 8, 7, 6, 5, 3, 2, 4, 1, 31, 0]
### llama_3 layer importance: [23, 24, 25, 27, 26, 22, 28, 21, 20, 19, 18, 29, 17, 11, 13, 16, 10, 12, 9, 15, 8, 14, 3, 2, 7, 6, 4, 5, 30, 31, 1, 0]
Qwen_LI = [3, 7, 12, 2, 17, 11, 13, 16, 15, 4, 18, 20, 8, 21, 10, 9, 14, 32, 1, 33, 19, 23, 34, 31, 22, 30, 29, 24, 26, 5, 25, 28, 27, 6, 0, 35]
llama_LI = [27, 26, 25, 28, 24, 29, 23, 21, 22, 30, 19, 20, 18, 17, 16, 11, 12, 13, 14, 15, 10, 9, 8, 7, 6, 5, 3, 2, 4, 1, 31, 0]

layer_to_quant = llama_LI[:16]
# layer_to_quant = Qwen_LI[:18]

mlp_quant = [f'layers.{item}.mlp' for item in layer_to_quant]
self_attn_quant = [f'layers.{item}.self_attn' for item in layer_to_quant]

print(f'layers to quant:{layer_to_quant}')

layers to quant:[27, 26, 25, 28, 24, 29, 23, 21, 22, 30, 19, 20, 18, 17, 16, 11]


## Model Quantization

In [6]:
scales_dict = {}

In [7]:
def pack_to_uint8(quant_x, in_f, out_f):
    # out_f 为行，in_f 为列
    i = 0
    rounds = int(out_f/2)

    new_weight = torch.zeros(rounds, in_f, dtype=torch.uint8)

    while i < rounds:

        row0 = quant_x[i] 
        row1 = quant_x[i + rounds] 

        packed_row0_row1 = ((row0 + 8).to(torch.uint8) << 4) | (row1 + 8).to(torch.uint8)  # 第一列（高4位）+第三列（低4位）

        new_weight[i] = packed_row0_row1

        i += 1
    
    return new_weight

def unpack_to_int8(packed_x, in_f, out_f):

    i = 0
    rounds = int(out_f/2)

    new_weight = torch.zeros(out_f, in_f, dtype=torch.int8)

    while i < rounds:

        packed_row0_row1 = packed_x[i]
        row0 = ((packed_row0_row1 >> 4) & 0x0F).to(torch.int8) - 8  # 高4位→原第一列
        row1 = (packed_row0_row1 & 0x0F).to(torch.int8) - 8        # 低4位→原第三列

        new_weight[i] = row0
        new_weight[i + rounds] = row1

        i += 1
    
    return new_weight

In [8]:
@torch.no_grad()
def quantize_weight_per_channel_absmax(w, name_in, in_features, out_features, n_bits=8):
    # w: (out_features, in_features)
    scales = w.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    # w.div_(scales).round_().mul_(scales)
    w.div_(scales).round_()
    w = w.to(torch.int8)
    scales_dict[name_in] = scales
    if n_bits == 4:
        w = pack_to_uint8(w, in_features, out_features)
    return w


@torch.no_grad()
def quantize_weight_per_tensor_absmax(w, name_in, in_features, out_features, n_bits=8):
    # w: (out_features, in_features)
    scales = w.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    # w.div_(scales).round_().mul_(scales)
    w.div_(scales).round_()
    w = w.to(torch.int8)
    scales_dict[name_in] = scales
    if n_bits == 4:
        w = pack_to_uint8(w, in_features, out_features)
    return w

In [9]:
class LSAQLinear(nn.Module):
    def __init__(
        self,
        name_in,
        bit_width,
        in_features,
        out_features,
        bias=True,
        quantize_output=False,
    ):
        super().__init__()
        self.name_in = name_in
        self.bit_width = bit_width
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randn(
                self.out_features,
                self.in_features,
                dtype=torch.float16,
                requires_grad=False,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (1, self.out_features), dtype=torch.float16, requires_grad=False
                ),
            )
        else:
            self.register_buffer("bias", None)

    def to(self, *args, **kwargs):
        super(LSAQLinear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        if self.bit_width == 4:
            weight = unpack_to_int8(self.weight, self.in_features, self.out_features)
            weight = weight.to('cuda')
            weight = weight.mul(pth_scales[self.name_in])
        else:
            weight = self.weight.mul(pth_scales[self.name_in])
            
        y = torch.functional.F.linear(x, weight, self.bias)
        return y

    @staticmethod
    def from_float(
        name_in, bit, module, weight_quant="per_channel", quantize_output=False
    ):
        assert isinstance(module, torch.nn.Linear)
        new_module = LSAQLinear(
            name_in,
            bit,
            module.in_features,
            module.out_features,
            module.bias is not None,
            quantize_output=quantize_output,
        )
        if weight_quant == "per_channel":
            new_module.weight = quantize_weight_per_channel_absmax(module.weight, name_in, module.in_features, module.out_features, bit)
        elif weight_quant == "per_tensor":
            new_module.weight = quantize_weight_per_tensor_absmax(module.weight, name_in, module.in_features, module.out_features, bit)
        else:
            raise ValueError(f"Invalid weight_quant: {weight_quant}")
        new_module.weight_quant_name = weight_quant
        if module.bias is not None:
            new_module.bias = module.bias
        return new_module

    def __repr__(self):
        return f"LSAQLinear({self.in_features}, {self.out_features}, bias={self.bias is not None}, weight_quant={self.weight_quant_name})"


In [10]:
def quantize_qwen_like(
    model, mlp_quant, self_attn_quant, low_bit, weight_quant="per_channel", quantize_bmm_input=False
):
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3Attention,
        Qwen3MLP,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, Qwen3MLP):
            print(name)
            if low_bit == 0:
                continue
            else:
                if name in mlp_quant:
                    bit = low_bit
                    print(f'{bit} bit quant')
                else:
                    if low_bit == 4:
                        bit = 8
                        print(f'{bit} bit quant')
                    elif low_bit == 8:
                        continue
            name_in = name + '.gate_proj'
            m.gate_proj = LSAQLinear.from_float(
                name_in, bit, m.gate_proj, weight_quant=weight_quant
            )
            name_in = name + '.up_proj'
            m.up_proj = LSAQLinear.from_float(
                name_in, bit, m.up_proj, weight_quant=weight_quant
            )
            name_in = name + '.down_proj'
            m.down_proj = LSAQLinear.from_float(
                name_in, bit, m.down_proj, weight_quant=weight_quant
            )
        elif isinstance(m, Qwen3Attention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            if low_bit == 0:
                    continue
            else:
                if name in self_attn_quant:
                    bit = low_bit
                else:
                    if low_bit == 4:
                        bit = 8
                    elif low_bit == 8:
                        continue
            name_in = name + '.q_proj'
            m.q_proj = LSAQLinear.from_float(
                name_in, 
                bit,  
                m.q_proj,
                weight_quant=weight_quant,
                quantize_output=quantize_bmm_input,
            )
            name_in = name + '.k_proj'
            m.k_proj = LSAQLinear.from_float(
                name_in, 
                bit, 
                m.k_proj,
                weight_quant=weight_quant,
                quantize_output=quantize_bmm_input,
            )
            name_in = name + '.v_proj'
            m.v_proj = LSAQLinear.from_float(
                name_in, 
                bit, 
                m.v_proj,
                weight_quant=weight_quant,
                quantize_output=quantize_bmm_input,
            )
            name_in = name + '.o_proj'
            m.o_proj = LSAQLinear.from_float(
                name_in, 
                bit, m.o_proj, weight_quant=weight_quant
            )
    return model


In [11]:
def quantize_llama_like(
    model, mlp_quant, self_attn_quant, low_bit, weight_quant="per_channel", quantize_bmm_input=False
):
    from transformers.models.llama.modeling_llama import (
        LlamaAttention,
        LlamaMLP,
    )

    for name, m in model.model.named_modules():
        if isinstance(m, LlamaMLP):
            print(name)
            if low_bit == 0:
                continue
            else:
                if name in mlp_quant:
                    bit = low_bit
                    print(f'{bit} bit quant')
                else:
                    if low_bit == 4:
                        bit = 8
                        print(f'{bit} bit quant')
                    elif low_bit == 8:
                        continue
            name_in = name + '.gate_proj'
            m.gate_proj = LSAQLinear.from_float(
                name_in, bit, m.gate_proj, weight_quant=weight_quant
            )
            name_in = name + '.up_proj'
            m.up_proj = LSAQLinear.from_float(
                name_in, bit, m.up_proj, weight_quant=weight_quant
            )
            name_in = name + '.down_proj'
            m.down_proj = LSAQLinear.from_float(
                name_in, bit, m.down_proj, weight_quant=weight_quant
            )
        elif isinstance(m, LlamaAttention):
            # Her we simulate quantizing BMM inputs by quantizing the output of q_proj, k_proj, v_proj
            if low_bit == 0:
                    continue
            else:
                if name in self_attn_quant:
                    bit = low_bit
                else:
                    if low_bit == 4:
                        bit = 8
                    elif low_bit == 8:
                        continue
            name_in = name + '.q_proj'
            m.q_proj = LSAQLinear.from_float(
                name_in, 
                bit,  
                m.q_proj,
                weight_quant=weight_quant,
                quantize_output=quantize_bmm_input,
            )
            name_in = name + '.k_proj'
            m.k_proj = LSAQLinear.from_float(
                name_in, 
                bit, 
                m.k_proj,
                weight_quant=weight_quant,
                quantize_output=quantize_bmm_input,
            )
            name_in = name + '.v_proj'
            m.v_proj = LSAQLinear.from_float(
                name_in, 
                bit, 
                m.v_proj,
                weight_quant=weight_quant,
                quantize_output=quantize_bmm_input,
            )
            name_in = name + '.o_proj'
            m.o_proj = LSAQLinear.from_float(
                name_in, 
                bit, m.o_proj, weight_quant=weight_quant
            )
    return model

In [12]:
model_aqi = quantize_llama_like(model, mlp_quant, self_attn_quant, 4)
# model_aqi = quantize_qwen_like(model, mlp_quant, self_attn_quant, 8)


layers.0.mlp
8 bit quant
layers.1.mlp
8 bit quant
layers.2.mlp
8 bit quant
layers.3.mlp
8 bit quant
layers.4.mlp
8 bit quant
layers.5.mlp
8 bit quant
layers.6.mlp
8 bit quant
layers.7.mlp
8 bit quant
layers.8.mlp
8 bit quant
layers.9.mlp
8 bit quant
layers.10.mlp
8 bit quant
layers.11.mlp
4 bit quant
layers.12.mlp
8 bit quant
layers.13.mlp
8 bit quant
layers.14.mlp
8 bit quant
layers.15.mlp
8 bit quant
layers.16.mlp
4 bit quant
layers.17.mlp
4 bit quant
layers.18.mlp
4 bit quant
layers.19.mlp
4 bit quant
layers.20.mlp
4 bit quant
layers.21.mlp
4 bit quant
layers.22.mlp
4 bit quant
layers.23.mlp
4 bit quant
layers.24.mlp
4 bit quant
layers.25.mlp
4 bit quant
layers.26.mlp
4 bit quant
layers.27.mlp
4 bit quant
layers.28.mlp
4 bit quant
layers.29.mlp
4 bit quant
layers.30.mlp
4 bit quant
layers.31.mlp
8 bit quant


In [13]:
torch.save(scales_dict, 'quantized_weight.pth')
pth_scales = torch.load('quantized_weight.pth')
model_aqi.cuda()
for key in pth_scales.keys():
    pth_scales[key] = pth_scales[key].to('cuda')

In [14]:
# prompt = "Hey, are you conscious? Can you talk to me?"
# model_aqi = model
# model_aqi.cuda()

prompt = "Where is the capital city of America"
inputs = tokenizer(prompt, return_tensors="pt").to(model_aqi.device)

# Generate
start_time = time.time()
generate_ids = model_aqi.generate(inputs.input_ids, max_length=20)
end_time = time.time()
speed = len(generate_ids[0])/(end_time-start_time)

print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

print(f"speed:{speed:.2f}token/s max memory:{torch.cuda.max_memory_allocated(model_aqi.device)/ 1024**2:.2f}M")

Where is the capital city of America?
 hopefully you know the answer is Washington D.C
speed:0.01token/s max memory:5354.14M
