In [2]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Parameter 
import torch.nn.functional as F

# import tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# test generate
inputs = tokenizer("Hello, world is ", return_tensors="pt")
outputs = model.generate(**inputs, max_length=68)
print(tokenizer.decode(outputs[0]))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, world is  going to be a lot more interesting than it was before.
I'm not sure if I'm going to be able to do this, but I'm going to be able to do it.
I'm going to be able to do it.
I'm going to be able to do it.



In [3]:
inputs

{'input_ids': tensor([[15496,    11,   995,   318,   220]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [4]:
for name, mod in model.named_modules():
    print(name, mod.__class__.__name__)

 GPT2LMHeadModel
transformer GPT2Model
transformer.wte Embedding
transformer.wpe Embedding
transformer.drop Dropout
transformer.h ModuleList
transformer.h.0 GPT2Block
transformer.h.0.ln_1 LayerNorm
transformer.h.0.attn GPT2Attention
transformer.h.0.attn.c_attn Conv1D
transformer.h.0.attn.c_proj Conv1D
transformer.h.0.attn.attn_dropout Dropout
transformer.h.0.attn.resid_dropout Dropout
transformer.h.0.ln_2 LayerNorm
transformer.h.0.mlp GPT2MLP
transformer.h.0.mlp.c_fc Conv1D
transformer.h.0.mlp.c_proj Conv1D
transformer.h.0.mlp.act NewGELUActivation
transformer.h.0.mlp.dropout Dropout
transformer.h.1 GPT2Block
transformer.h.1.ln_1 LayerNorm
transformer.h.1.attn GPT2Attention
transformer.h.1.attn.c_attn Conv1D
transformer.h.1.attn.c_proj Conv1D
transformer.h.1.attn.attn_dropout Dropout
transformer.h.1.attn.resid_dropout Dropout
transformer.h.1.ln_2 LayerNorm
transformer.h.1.mlp GPT2MLP
transformer.h.1.mlp.c_fc Conv1D
transformer.h.1.mlp.c_proj Conv1D
transformer.h.1.mlp.act NewGELUActiva

In [5]:
import yaml

with open("../configs/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["default_w_bits"])       
print(cfg["per_layer_bits"].keys())


8
dict_keys(['transformer.h.0.attn.c_attn', 'transformer.h.0.attn.c_proj', 'transformer.h.0.mlp.c_fc', 'transformer.h.0.mlp.c_proj', 'transformer.h.1.attn.c_attn', 'transformer.h.1.attn.c_proj', 'transformer.h.1.mlp.c_fc', 'transformer.h.1.mlp.c_proj', 'transformer.h.2.attn.c_attn', 'transformer.h.2.attn.c_proj', 'transformer.h.2.mlp.c_fc', 'transformer.h.2.mlp.c_proj', 'transformer.h.3.attn.c_attn', 'transformer.h.3.attn.c_proj', 'transformer.h.3.mlp.c_fc', 'transformer.h.3.mlp.c_proj', 'transformer.h.4.attn.c_attn', 'transformer.h.4.attn.c_proj', 'transformer.h.4.mlp.c_fc', 'transformer.h.4.mlp.c_proj', 'transformer.h.5.attn.c_attn', 'transformer.h.5.attn.c_proj', 'transformer.h.5.mlp.c_fc', 'transformer.h.5.mlp.c_proj', 'transformer.h.6.attn.c_attn', 'transformer.h.6.attn.c_proj', 'transformer.h.6.mlp.c_fc', 'transformer.h.6.mlp.c_proj', 'transformer.h.7.attn.c_attn', 'transformer.h.7.attn.c_proj', 'transformer.h.7.mlp.c_fc', 'transformer.h.7.mlp.c_proj', 'transformer.h.8.attn.c_att

In [6]:
import re
import torch.nn as nn

def want_quant(name, mod, cfg):
    # skip embedding / norm / lm_head
    # only Linear or Conv1D
    if name == "lm_head": 
        return False
    if isinstance(mod, nn.Linear) or mod.__class__.__name__ == "Conv1D":
        return True
    return False

In [7]:
list(model.named_modules())[10][1].weight.shape
a,b = list(model.named_modules())[9][1].weight.shape
a,b

(768, 2304)

In [8]:
class QuantLinear(nn.Module):
    r"""Quantized version of nn.Linear.

    This layer works just like nn.Linear, but the weights are stored
    in int8 format to save memory and improve efficiency.

    Instead of a single global scale, each output channel has its own
    scale factor. This makes the quantization more accurate because
    different output channels can have very different weight ranges.

    Input shape:  (*, in_features)
    Output shape: (*, out_features)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("qweight",
            torch.empty(out_features, in_features, dtype=torch.int8, device=device))
        self.register_buffer("w_scale",
            torch.ones(out_features, dtype=torch.float32, device=device))
        self.register_buffer("w_zp",
            torch.zeros(out_features, dtype=torch.int32, device=device))
        self.register_buffer("fp32_weight", None)  # orignal
        self.current_bits = None  # current 
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def store_fp32_weight(self, weight: torch.Tensor):
        self.fp32_weight = weight.clone()

    def requantize_to_bits(self, bits: int):
        if self.current_bits == bits:
            return 
        
        self.quantize_from_float(self.fp32_weight, bits=bits)
        self.current_bits = bits


    def forward(self, input: Tensor) -> Tensor:
        if torch.any(self.w_zp != 0):
            # non-zero： (q - zp) * scale
            W = (self.qweight.int() - self.w_zp.view(-1, 1)).float() * self.w_scale.view(-1, 1)
        else:
            # zero w zp： q * scale
            W = self.qweight.float() * self.w_scale.view(-1, 1)
        return F.linear(input, W, self.bias)

    def extra_repr(self) -> str:
        """
        Return the extra representation of the quant module.
        """
        return (f"in_features={self.in_features}, out_features={self.out_features}, "
                f"bias={self.bias is not None}, dtype=int8, per_channel=True")
                
    @staticmethod
    def get_bits_for_layer(name: str, cfg: dict) -> int:
        return cfg["per_layer_bits"][name]

    def quantize_from_float(self, weight: torch.Tensor, bits: int = 8):
        # support 2-8 bits
        qmin, qmax = -(2**(bits-1)), 2**(bits-1) - 1 
        w_max_abs = weight.abs().max(dim=1, keepdim=True)[0]
        w_max_abs = torch.clamp(w_max_abs, min=1e-8)
        scale = w_max_abs / qmax
        qweight = torch.clamp(torch.round(weight / scale), qmin, qmax).to(torch.int8)
        zero_point = torch.zeros(weight.size(0), dtype=torch.int32, device=weight.device)
        self.qweight.copy_(qweight)
        self.w_scale.copy_(scale.squeeze())
        self.w_zp.copy_(zero_point)

    @classmethod
    def from_linear(cls, base: nn.Linear, name: str, cfg: dict):
        bits = cls.get_bits_for_layer(name, cfg)
        q = cls(base.in_features, base.out_features,
                bias=(base.bias is not None),
                device=base.weight.device, dtype=base.weight.dtype)
        with torch.no_grad():
            
            q.store_fp32_weight(base.weight)

            bits = cls.get_bits_for_layer(name, cfg)
            q.quantize_from_float(base.weight, bits=bits)
            q.current_bits = bits
            
            if base.bias is not None:
                q.bias.copy_(base.bias)
        return q


In [9]:
print(QuantLinear.get_bits_for_layer("transformer.h.0.attn.c_attn", cfg))   


8


In [10]:
def requantize_model_to_config(model, cfg):
        
    default_bits = cfg.get('default_w_bits', 8)
    per_layer_bits = cfg.get('per_layer_bits', {})
    
    for name, module in model.named_modules():
        if isinstance(module, QuantLinear):
            target_bits = per_layer_bits.get(name, default_bits)
            module.requantize_to_bits(target_bits)

In [11]:
def replace_with_quant(model, cfg):
    name_to_module = dict(model.named_modules())
    for name, mod in list(name_to_module.items()):
        print(name, mod.__class__.__name__)
        if not want_quant(name, mod, cfg):
            continue

        # parent module location
        if '.' in name:
            parent_name, child_name = name.rsplit('.', 1)
            parent = name_to_module[parent_name]
        else:
            parent, child_name = model, name

        # convert Conv1D to Linear
        if mod.__class__.__name__ == "Conv1D":
                in_f, out_f = mod.weight.shape          # Conv1D weight is (out, in)
                base = nn.Linear(in_f, out_f, bias=(mod.bias is not None))
                base.to(mod.weight.device, dtype=mod.weight.dtype)
                with torch.no_grad():
                    base.weight.copy_(mod.weight.T)       # copy 
                    if mod.bias is not None:
                        base.bias.copy_(mod.bias)
        else:
            base = mod

        # construct quantization wrapper
        qcfg = cfg.copy()
        qmod = QuantLinear.from_linear(base, name,cfg=qcfg)

        # replace the original layer with setattr
        setattr(parent, child_name, qmod)

replace_with_quant(model, cfg)


 GPT2LMHeadModel
transformer GPT2Model
transformer.wte Embedding
transformer.wpe Embedding
transformer.drop Dropout
transformer.h ModuleList
transformer.h.0 GPT2Block
transformer.h.0.ln_1 LayerNorm
transformer.h.0.attn GPT2Attention
transformer.h.0.attn.c_attn Conv1D
transformer.h.0.attn.c_proj Conv1D
transformer.h.0.attn.attn_dropout Dropout
transformer.h.0.attn.resid_dropout Dropout
transformer.h.0.ln_2 LayerNorm
transformer.h.0.mlp GPT2MLP
transformer.h.0.mlp.c_fc Conv1D
transformer.h.0.mlp.c_proj Conv1D
transformer.h.0.mlp.act NewGELUActivation
transformer.h.0.mlp.dropout Dropout
transformer.h.1 GPT2Block
transformer.h.1.ln_1 LayerNorm
transformer.h.1.attn GPT2Attention
transformer.h.1.attn.c_attn Conv1D
transformer.h.1.attn.c_proj Conv1D
transformer.h.1.attn.attn_dropout Dropout
transformer.h.1.attn.resid_dropout Dropout
transformer.h.1.ln_2 LayerNorm
transformer.h.1.mlp GPT2MLP
transformer.h.1.mlp.c_fc Conv1D
transformer.h.1.mlp.c_proj Conv1D
transformer.h.1.mlp.act NewGELUActiva

In [29]:
# Step 1 test - fresh model with original methods
test_model = GPT2LMHeadModel.from_pretrained("gpt2")
replace_with_quant(test_model, cfg)

# Test 2bit/6bit requantization  
with open("../configs/test_2bit_6bit.yaml", "r") as f:
    test_cfg = yaml.safe_load(f)
requantize_model_to_config(test_model, test_cfg)

# Generation test
inputs = tokenizer("Hello, world is ", return_tensors="pt")
with torch.no_grad():
    outputs = test_model.generate(**inputs, max_length=30)
print("Result:", tokenizer.decode(outputs[0]))

 GPT2LMHeadModel
transformer GPT2Model
transformer.wte Embedding
transformer.wpe Embedding
transformer.drop Dropout
transformer.h ModuleList
transformer.h.0 GPT2Block
transformer.h.0.ln_1 LayerNorm
transformer.h.0.attn GPT2Attention
transformer.h.0.attn.c_attn Conv1D
transformer.h.0.attn.c_proj Conv1D
transformer.h.0.attn.attn_dropout Dropout
transformer.h.0.attn.resid_dropout Dropout
transformer.h.0.ln_2 LayerNorm
transformer.h.0.mlp GPT2MLP
transformer.h.0.mlp.c_fc Conv1D
transformer.h.0.mlp.c_proj Conv1D
transformer.h.0.mlp.act NewGELUActivation
transformer.h.0.mlp.dropout Dropout
transformer.h.1 GPT2Block
transformer.h.1.ln_1 LayerNorm
transformer.h.1.attn GPT2Attention
transformer.h.1.attn.c_attn Conv1D
transformer.h.1.attn.c_proj Conv1D
transformer.h.1.attn.attn_dropout Dropout
transformer.h.1.attn.resid_dropout Dropout
transformer.h.1.ln_2 LayerNorm
transformer.h.1.mlp GPT2MLP
transformer.h.1.mlp.c_fc Conv1D
transformer.h.1.mlp.c_proj Conv1D
transformer.h.1.mlp.act NewGELUActiva

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


transformer.h.10.mlp.act NewGELUActivation
transformer.h.10.mlp.dropout Dropout
transformer.h.11 GPT2Block
transformer.h.11.ln_1 LayerNorm
transformer.h.11.attn GPT2Attention
transformer.h.11.attn.c_attn Conv1D
transformer.h.11.attn.c_proj Conv1D
transformer.h.11.attn.attn_dropout Dropout
transformer.h.11.attn.resid_dropout Dropout
transformer.h.11.ln_2 LayerNorm
transformer.h.11.mlp GPT2MLP
transformer.h.11.mlp.c_fc Conv1D
transformer.h.11.mlp.c_proj Conv1D
transformer.h.11.mlp.act NewGELUActivation
transformer.h.11.mlp.dropout Dropout
transformer.ln_f LayerNorm
lm_head Linear
Result: Hello, world is  in the world. 
I've been in the a small small small small small


In [12]:

inputs = tokenizer("Hello, world is ", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=68)

print(tokenizer.decode(outputs[0]))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Hello, world is  going to be a lot more interesting than it was before.
I'm not sure if I'm going to be able to do this, but I'm going to be able to do it.
I'm going to be able to do it.
I'm going to be able to do it.



In [13]:
import torch
import torch.nn as nn

class LoRA(nn.Module):
    def __init__(self, in_f, out_f, r=4, alpha=None):
        super().__init__()
        self.scale = (alpha or r) / r
        self.A = nn.Parameter(torch.randn(r, in_f) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_f, r))
    def forward(self, x):
        return (x @ self.A.t()) @ self.B.t() * self.scale

class LoRAWrapped(nn.Module):
    def __init__(self, base, branches, layer_name=None):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False

        # keep device and precision
        dev = next(self.base.parameters()).device
        dtype = next(self.base.parameters()).dtype

        in_f = self.base.in_features
        out_f = self.base.out_features

        self.bank = nn.ModuleDict({
            k: LoRA(in_f, out_f, r, a).to(device=dev, dtype=dtype)
            for k, (r, a) in branches.items()
        })
        self.active = None            # only one branch
        self.layer_name = layer_name  # for logging / routing (optional)

    def set_active(self, name_or_none):
        self.active = name_or_none   # name_or_none: 'bw4' / 'bw8' / None

    def forward(self, x):
        y = self.base(x)
        if self.active in self.bank:
            y = y + self.bank[self.active](x)
        return y

def attach_lora_to_quant(model, name2branches, quant_cfg):
    """
    Lora and quant at the same time
    """
    # quant
    replace_with_quant(model, quant_cfg)
    
    # add LoRA
    wrappers = {}
    for name, mod in list(model.named_modules()):
        if name in name2branches and hasattr(mod, 'in_features'):
            parent = model.get_submodule(name.rsplit('.',1)[0]) if '.' in name else model
            attr = name.split('.')[-1]
            w = LoRAWrapped(mod, name2branches[name], layer_name=name)
            setattr(parent, attr, w)
            wrappers[name] = w
    return wrappers

# activate by bit config (call this before inference)
def activate_lora_by_bits(wrappers, bit_cfg, default_bits=None):
    m = {4: "bw4", 8: "bw8"}  # 4/8
    for n, w in wrappers.items():
        bw = bit_cfg.get(n, default_bits)
        w.set_active(m[bw])  


In [14]:
# define the quantization bit width and LoRA branches for each layer
lora_spec = {
    "transformer.h.0.attn.c_attn": {"bw4": (8,16), "bw8": (4,8)},
    "transformer.h.1.attn.c_attn": {"bw4": (8,16), "bw8": (4,8)},
    # ... more layers
}

# one-step: quant + LoRA
wrappers = attach_lora_to_quant(model, lora_spec, cfg)

# 
for name, wrapper in wrappers.items():
    bits = cfg["per_layer_bits"].get(name, cfg["default_w_bits"])
    if bits <= 4:
        wrapper.set_active("bw4")
    else:   
        wrapper.set_active("bw8")

 GPT2LMHeadModel
transformer GPT2Model
transformer.wte Embedding
transformer.wpe Embedding
transformer.drop Dropout
transformer.h ModuleList
transformer.h.0 GPT2Block
transformer.h.0.ln_1 LayerNorm
transformer.h.0.attn GPT2Attention
transformer.h.0.attn.c_attn QuantLinear
transformer.h.0.attn.c_proj QuantLinear
transformer.h.0.attn.attn_dropout Dropout
transformer.h.0.attn.resid_dropout Dropout
transformer.h.0.ln_2 LayerNorm
transformer.h.0.mlp GPT2MLP
transformer.h.0.mlp.c_fc QuantLinear
transformer.h.0.mlp.c_proj QuantLinear
transformer.h.0.mlp.act NewGELUActivation
transformer.h.0.mlp.dropout Dropout
transformer.h.1 GPT2Block
transformer.h.1.ln_1 LayerNorm
transformer.h.1.attn GPT2Attention
transformer.h.1.attn.c_attn QuantLinear
transformer.h.1.attn.c_proj QuantLinear
transformer.h.1.attn.attn_dropout Dropout
transformer.h.1.attn.resid_dropout Dropout
transformer.h.1.ln_2 LayerNorm
transformer.h.1.mlp GPT2MLP
transformer.h.1.mlp.c_fc QuantLinear
transformer.h.1.mlp.c_proj QuantLine

In [15]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import random

# Load dataset
squad_dataset = load_dataset("squad", split="train")

# Format dataset to prompt
def format_squad_prompt(sample):
    return f"question: {sample['question']} context: {sample['context']} answer: {sample['answers']['text'][0]}"

# Create a small subset
subset_indices = random.sample(range(len(squad_dataset)), 1000)
squad_subset = squad_dataset.select(subset_indices)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    
# DataLoader
def collate_fn(batch):
    prompts = [format_squad_prompt(s) for s in batch]
    return tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=128)

train_dataloader = DataLoader(squad_subset, batch_size=4, shuffle=True, collate_fn=collate_fn)

In [16]:
# Load two config files directly into a list
with open("../configs/config.yaml", 'r') as f: config_A = yaml.safe_load(f)
with open("../configs/config_4bit.yaml", 'r') as f: config_B = yaml.safe_load(f)

precision_configs = [config_A, config_B]

# Collect all LoRA module parameters
lora_params = [
    p for w in wrappers.values() for p in w.bank.parameters() if p.requires_grad
]
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)

In [17]:
import random
from tqdm import tqdm

model.train()  # Set to training mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Use tqdm to create a progress bar
progress_bar = tqdm(range(1000))

# Get data from dataloader
data_iter = iter(train_dataloader)

for i in progress_bar:
    # If data is used up, create a new iterator
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(train_dataloader)
        batch = next(data_iter)

    # Randomly select a precision configuration
    chosen_config = random.choice(precision_configs)

    # Activate the corresponding LoRA branch based on the selected configuration
    per_layer_config = chosen_config.get('per_layer_bits', {})
    default_bits = chosen_config.get('default_w_bits')
    activate_lora_by_bits(wrappers, per_layer_config, default_bits)

    #training process
    inputs = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Update the progress bar display
    progress_bar.set_description(f"Iteration {i+1} | Loss: {loss.item():.3f} | Config: {chosen_config['default_w_bits']}-bit")

print("\nTraining completed!")

  0%|          | 0/1000 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
Iteration 1000 | Loss: 3.854 | Config: 8-bit: 100%|██████████| 1000/1000 [00:43<00:00, 22.96it/s]


Training completed!





In [18]:
# --- Quick smoke test for 2 configs ---
model.eval()
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

q = "What is the capital of France?"
c = "France is a country in Western Europe. Its capital and largest city is Paris."
prompt = f"question: {q} context: {c} answer:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

@torch.no_grad()
def gen_with(cfg):
    activate_lora_by_bits(wrappers, cfg.get('per_layer_bits', {}), cfg.get('default_w_bits'))
    out = model.generate(
        **inputs,
        max_new_tokens=32, do_sample=False,  # 更稳定
        pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id
    )
    gen = out[0, inputs['input_ids'].shape[1]:]      # 只取新生成部分
    return tokenizer.decode(gen, skip_special_tokens=True).strip()

print("[A]", cfg_A_name := config_A.get("name", f"default{config_A.get('default_w_bits')}"), "=>", gen_with(config_A))
print("[B]", cfg_B_name := config_B.get("name", f"default{config_B.get('default_w_bits')}"), "=>", gen_with(config_B))


[A] default8 => France is a country in Western Europe. Its capital and largest city is Paris. Its capital and largest city is Paris. Its capital and largest city is Paris.
[B] default4 => France is a country in Western Europe. Its capital and largest city is Paris. Its capital and largest city is Paris. Its capital and largest city is Paris.


In [19]:
# cell 16 (修改后)
import math, time, torch, random, gc
from datasets import load_dataset
from torch.utils.data import DataLoader

# --- 1. 在 Cell 的最开始就定义好全局设备 ---
# 这样后续所有操作都会默认使用这个设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"评估将使用设备: {device}")

model.to(device) # 确保模型在正确的设备上
model.eval()

# 2) 准备验证集 (这部分不变)
val_ds = load_dataset("squad", split="validation")
def fmt(s): return f"question: {s['question']} context: {s['context']} answer: {s['answers']['text'][0]}"
idx = random.sample(range(len(val_ds)), 2000) # 用一个更小的子集快速测试
val_ds = val_ds.select(idx)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def collate(batch):
    prompts = [fmt(x) for x in batch]
    return tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=128)

val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, collate_fn=collate)

# 3) 激活 LoRA 分支 (这部分不变)
def set_bits(cfg):
    per_layer = cfg.get("per_layer_bits", {})
    default_b = cfg.get("default_w_bits", 8)
    activate_lora_by_bits(wrappers, per_layer, default_b)
    return default_b

# 4) 单配置评估 (修改内存统计部分)
@torch.no_grad()
def eval_config(cfg, max_batches=None):
    requantize_model_to_config(model, cfg)
    bits = set_bits(cfg)
    model.eval()
    

    # --- 明确指定要监控的设备 ---
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.reset_peak_memory_stats(device)

    tot_loss, tot_tok, tot_time, seen = 0.0, 0, 0.0, 0

    for i, batch in tqdm(enumerate(val_loader), total=len(val_loader), desc=f"Eval {bits}-bit"):
        if (max_batches is not None) and (i >= max_batches): break

        # inputs 已经在这里被移动到 device
        inputs = {k: v.to(device) for k, v in batch.items()}

        if device.type == 'cuda': torch.cuda.synchronize(device)
        t0 = time.time()
        out = model(**inputs, labels=inputs["input_ids"])
        if device.type == 'cuda': torch.cuda.synchronize(device)
        tot_time += (time.time() - t0)

        bs = inputs["input_ids"].size(0)
        tot_loss += out.loss.item() * bs
        tot_tok  += inputs["input_ids"].numel()
        seen     += bs

    # --- 在循环结束后，从正确的设备读取内存峰值 ---
    peak = torch.cuda.max_memory_allocated(device) if device.type == 'cuda' else 0
    
    avg_loss = tot_loss / max(seen, 1)
    ppl = math.exp(avg_loss) if avg_loss < 20 else float("inf")
    tps = tot_tok / max(tot_time, 1e-6)
    mem_mb = peak / (1024**2)

    # 给配置一个名字，如果 YAML 里没有的话
    config_name = cfg.get("name", f"config_default_{cfg.get('default_w_bits','N/A')}")
    print(f"[{config_name}] bits={bits} | loss={avg_loss:.3f} | ppl={ppl:.2f} | tokens/s={tps:.0f} | peak={mem_mb:.0f}MB")
    return {"name": config_name, "bits": bits, "loss": avg_loss, "ppl": ppl, "tps": tps, "memMB": mem_mb}

# 5) 跑两种配置 (不变)
results = []
# 确保你的 config_A 和 config_B 在之前的 cell 中被正确定义
# 比如在 config.yaml 和 config_4bit.yaml 中分别添加一行 'name: config_8bit' 和 'name: config_4bit'
for cfg in [config_A, config_B]:
    torch.cuda.reset_peak_memory_stats(device)
    results.append(eval_config(cfg))

# 6) 打印排行榜 (不变)
results.sort(key=lambda x: x["ppl"])
print("\n== Leaderboard (by PPL) ==")
for r in results:
    print(f"{r['name']:>15} | ppl={r['ppl']:.2f} | loss={r['loss']:.3f} | mem={r['memMB']:.0f}MB | tps={r['tps']:.0f}")

评估将使用设备: cuda


Eval 8-bit: 100%|██████████| 250/250 [00:09<00:00, 26.75it/s]


[config_default_8] bits=8 | loss=3.468 | ppl=32.06 | tokens/s=35365 | peak=1620MB


Eval 4-bit: 100%|██████████| 250/250 [00:09<00:00, 27.61it/s]

[config_default_4] bits=4 | loss=3.719 | ppl=41.22 | tokens/s=35353 | peak=1620MB

== Leaderboard (by PPL) ==
config_default_8 | ppl=32.06 | loss=3.468 | mem=1620MB | tps=35365
config_default_4 | ppl=41.22 | loss=3.719 | mem=1620MB | tps=35353





In [20]:
def requantize_model_to_config(model, cfg):
    default_bits = cfg.get('default_w_bits', 8)
    per_layer_bits = cfg.get('per_layer_bits', {})
    
    print(f"🔧 Requantizing model to {default_bits}-bit...")
    count = 0
    for name, module in model.named_modules():
        if isinstance(module, QuantLinear):
            target_bits = per_layer_bits.get(name, default_bits)
            old_bits = getattr(module, 'current_bits', 'unknown')
            module.requantize_to_bits(target_bits)
            print(f"  {name}: {old_bits} -> {target_bits} bits")
            count += 1
    print(f"✅ Requantized {count} layers")

In [None]:
import glob, os, yaml, torch, evaluate, pandas as pd
from tqdm import tqdm

model.to(device); model.eval()
metric = evaluate.load("squad")
cfg_paths = sorted(glob.glob("../configs/*.yaml"))

def em_f1_for_cfg(cfg):
    requantize_model_to_config(model, cfg)
    set_bits(cfg)

    preds, refs = [], []
    subset = val_ds.select(range(min(n, len(val_ds))))
    for ex in tqdm(subset, leave=False):
        prompt = f"question: {ex['question']} context: {ex['context']} answer:"
        inp = tokenizer(prompt, return_tensors="pt").to(device)
        out = model.generate(**inp, max_new_tokens=30, pad_token_id=tokenizer.eos_token_id)
        gen_ids = out[0, inp['input_ids'].size(1):]
        ans = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        preds.append({"id": ex["id"], "prediction_text": ans})
        refs.append({"id": ex["id"], "answers": ex["answers"]})
    return metric.compute(predictions=preds, references=refs)

rows = []
for p in cfg_paths:
    with open(p, "r") as f:
        cfg = yaml.safe_load(f)
    name = cfg.get("name", os.path.basename(p).replace(".yaml",""))
    res = em_f1_for_cfg(cfg)
    rows.append({
        "config": name,
        "EM": res["exact_match"],
        "F1": res["f1"],
        "default_bits": cfg.get("default_w_bits","-")
    })

df = pd.DataFrame(rows).sort_values("F1", ascending=False).reset_index(drop=True)
print(df.to_string(index=False))

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 4 -> 8 bits
  transformer.h.0.mlp.c_fc: 4 -> 8 bits
  transformer.h.0.mlp.c_proj: 4 -> 8 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 4 -> 4 bits
  transformer.h.1.mlp.c_fc: 4 -> 4 bits
  transformer.h.1.mlp.c_proj: 4 -> 4 bits
  transformer.h.2.attn.c_attn: 4 -> 4 bits
  transformer.h.2.attn.c_proj: 4 -> 4 bits
  transformer.h.2.mlp.c_fc: 4 -> 4 bits
  transformer.h.2.mlp.c_proj: 4 -> 4 bits
  transformer.h.3.attn.c_attn: 4 -> 4 bits
  transformer.h.3.attn.c_proj: 4 -> 4 bits
  transformer.h.3.mlp.c_fc: 4 -> 4 bits
  transformer.h.3.mlp.c_proj: 4 -> 4 bits
  transformer.h.4.attn.c_attn: 4 -> 4 bits
  transformer.h.4.attn.c_proj: 4 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 4 bits
  transformer.h.5.attn.c_attn: 4 -> 4 bits
  transformer.h.5.attn.c_proj: 4 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 8-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 8 bits
  transformer.h.0.attn.c_proj: 8 -> 8 bits
  transformer.h.0.mlp.c_fc: 8 -> 8 bits
  transformer.h.0.mlp.c_proj: 8 -> 8 bits
  transformer.h.1.attn.c_attn.base: 4 -> 8 bits
  transformer.h.1.attn.c_proj: 4 -> 8 bits
  transformer.h.1.mlp.c_fc: 4 -> 8 bits
  transformer.h.1.mlp.c_proj: 4 -> 8 bits
  transformer.h.2.attn.c_attn: 4 -> 8 bits
  transformer.h.2.attn.c_proj: 4 -> 8 bits
  transformer.h.2.mlp.c_fc: 4 -> 8 bits
  transformer.h.2.mlp.c_proj: 4 -> 8 bits
  transformer.h.3.attn.c_attn: 4 -> 8 bits
  transformer.h.3.attn.c_proj: 4 -> 8 bits
  transformer.h.3.mlp.c_fc: 4 -> 8 bits
  transformer.h.3.mlp.c_proj: 4 -> 8 bits
  transformer.h.4.attn.c_attn: 4 -> 8 bits
  transformer.h.4.attn.c_proj: 4 -> 8 bits
  transformer.h.4.mlp.c_fc: 4 -> 8 bits
  transformer.h.4.mlp.c_proj: 4 -> 8 bits
  transformer.h.5.attn.c_attn: 4 -> 8 bits
  transformer.h.5.attn.c_proj: 4 -> 8 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 8 -> 4 bits
  transformer.h.0.attn.c_proj: 8 -> 4 bits
  transformer.h.0.mlp.c_fc: 8 -> 4 bits
  transformer.h.0.mlp.c_proj: 8 -> 4 bits
  transformer.h.1.attn.c_attn.base: 8 -> 4 bits
  transformer.h.1.attn.c_proj: 8 -> 4 bits
  transformer.h.1.mlp.c_fc: 8 -> 4 bits
  transformer.h.1.mlp.c_proj: 8 -> 4 bits
  transformer.h.2.attn.c_attn: 8 -> 4 bits
  transformer.h.2.attn.c_proj: 8 -> 4 bits
  transformer.h.2.mlp.c_fc: 8 -> 4 bits
  transformer.h.2.mlp.c_proj: 8 -> 4 bits
  transformer.h.3.attn.c_attn: 8 -> 4 bits
  transformer.h.3.attn.c_proj: 8 -> 4 bits
  transformer.h.3.mlp.c_fc: 8 -> 4 bits
  transformer.h.3.mlp.c_proj: 8 -> 4 bits
  transformer.h.4.attn.c_attn: 8 -> 4 bits
  transformer.h.4.attn.c_proj: 8 -> 4 bits
  transformer.h.4.mlp.c_fc: 8 -> 4 bits
  transformer.h.4.mlp.c_proj: 8 -> 4 bits
  transformer.h.5.attn.c_attn: 8 -> 4 bits
  transformer.h.5.attn.c_proj: 8 -> 4 bits
  transformer.h.5.mlp.c_fc: 8 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 4 -> 8 bits
  transformer.h.0.mlp.c_fc: 4 -> 8 bits
  transformer.h.0.mlp.c_proj: 4 -> 8 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 4 -> 8 bits
  transformer.h.1.mlp.c_fc: 4 -> 8 bits
  transformer.h.1.mlp.c_proj: 4 -> 8 bits
  transformer.h.2.attn.c_attn: 4 -> 8 bits
  transformer.h.2.attn.c_proj: 4 -> 8 bits
  transformer.h.2.mlp.c_fc: 4 -> 8 bits
  transformer.h.2.mlp.c_proj: 4 -> 8 bits
  transformer.h.3.attn.c_attn: 4 -> 8 bits
  transformer.h.3.attn.c_proj: 4 -> 8 bits
  transformer.h.3.mlp.c_fc: 4 -> 8 bits
  transformer.h.3.mlp.c_proj: 4 -> 8 bits
  transformer.h.4.attn.c_attn: 4 -> 4 bits
  transformer.h.4.attn.c_proj: 4 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 4 bits
  transformer.h.5.attn.c_attn: 4 -> 4 bits
  transformer.h.5.attn.c_proj: 4 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 8 -> 4 bits
  transformer.h.0.mlp.c_fc: 8 -> 4 bits
  transformer.h.0.mlp.c_proj: 8 -> 4 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 8 -> 4 bits
  transformer.h.1.mlp.c_fc: 8 -> 4 bits
  transformer.h.1.mlp.c_proj: 8 -> 4 bits
  transformer.h.2.attn.c_attn: 8 -> 4 bits
  transformer.h.2.attn.c_proj: 8 -> 4 bits
  transformer.h.2.mlp.c_fc: 8 -> 4 bits
  transformer.h.2.mlp.c_proj: 8 -> 4 bits
  transformer.h.3.attn.c_attn: 8 -> 4 bits
  transformer.h.3.attn.c_proj: 8 -> 4 bits
  transformer.h.3.mlp.c_fc: 8 -> 4 bits
  transformer.h.3.mlp.c_proj: 8 -> 4 bits
  transformer.h.4.attn.c_attn: 4 -> 4 bits
  transformer.h.4.attn.c_proj: 4 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 4 bits
  transformer.h.5.attn.c_attn: 4 -> 4 bits
  transformer.h.5.attn.c_proj: 4 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 4 -> 8 bits
  transformer.h.0.mlp.c_fc: 4 -> 8 bits
  transformer.h.0.mlp.c_proj: 4 -> 8 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 4 -> 8 bits
  transformer.h.1.mlp.c_fc: 4 -> 8 bits
  transformer.h.1.mlp.c_proj: 4 -> 8 bits
  transformer.h.2.attn.c_attn: 4 -> 4 bits
  transformer.h.2.attn.c_proj: 4 -> 4 bits
  transformer.h.2.mlp.c_fc: 4 -> 4 bits
  transformer.h.2.mlp.c_proj: 4 -> 4 bits
  transformer.h.3.attn.c_attn: 4 -> 4 bits
  transformer.h.3.attn.c_proj: 4 -> 4 bits
  transformer.h.3.mlp.c_fc: 4 -> 4 bits
  transformer.h.3.mlp.c_proj: 4 -> 4 bits
  transformer.h.4.attn.c_attn: 4 -> 4 bits
  transformer.h.4.attn.c_proj: 4 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 4 bits
  transformer.h.5.attn.c_attn: 4 -> 4 bits
  transformer.h.5.attn.c_proj: 4 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 8 -> 4 bits
  transformer.h.0.mlp.c_fc: 8 -> 4 bits
  transformer.h.0.mlp.c_proj: 8 -> 4 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 8 -> 4 bits
  transformer.h.1.mlp.c_fc: 8 -> 4 bits
  transformer.h.1.mlp.c_proj: 8 -> 4 bits
  transformer.h.2.attn.c_attn: 4 -> 8 bits
  transformer.h.2.attn.c_proj: 4 -> 4 bits
  transformer.h.2.mlp.c_fc: 4 -> 4 bits
  transformer.h.2.mlp.c_proj: 4 -> 4 bits
  transformer.h.3.attn.c_attn: 4 -> 8 bits
  transformer.h.3.attn.c_proj: 4 -> 4 bits
  transformer.h.3.mlp.c_fc: 4 -> 4 bits
  transformer.h.3.mlp.c_proj: 4 -> 4 bits
  transformer.h.4.attn.c_attn: 4 -> 8 bits
  transformer.h.4.attn.c_proj: 4 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 4 bits
  transformer.h.5.attn.c_attn: 4 -> 8 bits
  transformer.h.5.attn.c_proj: 4 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 4 -> 8 bits
  transformer.h.0.mlp.c_fc: 4 -> 4 bits
  transformer.h.0.mlp.c_proj: 4 -> 4 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 4 -> 8 bits
  transformer.h.1.mlp.c_fc: 4 -> 4 bits
  transformer.h.1.mlp.c_proj: 4 -> 4 bits
  transformer.h.2.attn.c_attn: 8 -> 4 bits
  transformer.h.2.attn.c_proj: 4 -> 8 bits
  transformer.h.2.mlp.c_fc: 4 -> 4 bits
  transformer.h.2.mlp.c_proj: 4 -> 4 bits
  transformer.h.3.attn.c_attn: 8 -> 4 bits
  transformer.h.3.attn.c_proj: 4 -> 8 bits
  transformer.h.3.mlp.c_fc: 4 -> 4 bits
  transformer.h.3.mlp.c_proj: 4 -> 4 bits
  transformer.h.4.attn.c_attn: 8 -> 4 bits
  transformer.h.4.attn.c_proj: 4 -> 8 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 4 bits
  transformer.h.5.attn.c_attn: 8 -> 4 bits
  transformer.h.5.attn.c_proj: 4 -> 8 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 8 -> 4 bits
  transformer.h.0.mlp.c_fc: 4 -> 4 bits
  transformer.h.0.mlp.c_proj: 4 -> 8 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 8 -> 4 bits
  transformer.h.1.mlp.c_fc: 4 -> 4 bits
  transformer.h.1.mlp.c_proj: 4 -> 8 bits
  transformer.h.2.attn.c_attn: 4 -> 4 bits
  transformer.h.2.attn.c_proj: 8 -> 4 bits
  transformer.h.2.mlp.c_fc: 4 -> 4 bits
  transformer.h.2.mlp.c_proj: 4 -> 8 bits
  transformer.h.3.attn.c_attn: 4 -> 4 bits
  transformer.h.3.attn.c_proj: 8 -> 4 bits
  transformer.h.3.mlp.c_fc: 4 -> 4 bits
  transformer.h.3.mlp.c_proj: 4 -> 8 bits
  transformer.h.4.attn.c_attn: 4 -> 4 bits
  transformer.h.4.attn.c_proj: 8 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 4 -> 8 bits
  transformer.h.5.attn.c_attn: 4 -> 4 bits
  transformer.h.5.attn.c_proj: 8 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 4 bits
  transformer.h.0.attn.c_proj: 4 -> 4 bits
  transformer.h.0.mlp.c_fc: 4 -> 4 bits
  transformer.h.0.mlp.c_proj: 8 -> 4 bits
  transformer.h.1.attn.c_attn.base: 4 -> 4 bits
  transformer.h.1.attn.c_proj: 4 -> 4 bits
  transformer.h.1.mlp.c_fc: 4 -> 4 bits
  transformer.h.1.mlp.c_proj: 8 -> 4 bits
  transformer.h.2.attn.c_attn: 4 -> 4 bits
  transformer.h.2.attn.c_proj: 4 -> 4 bits
  transformer.h.2.mlp.c_fc: 4 -> 4 bits
  transformer.h.2.mlp.c_proj: 8 -> 4 bits
  transformer.h.3.attn.c_attn: 4 -> 4 bits
  transformer.h.3.attn.c_proj: 4 -> 4 bits
  transformer.h.3.mlp.c_fc: 4 -> 4 bits
  transformer.h.3.mlp.c_proj: 8 -> 4 bits
  transformer.h.4.attn.c_attn: 4 -> 4 bits
  transformer.h.4.attn.c_proj: 4 -> 4 bits
  transformer.h.4.mlp.c_fc: 4 -> 4 bits
  transformer.h.4.mlp.c_proj: 8 -> 4 bits
  transformer.h.5.attn.c_attn: 4 -> 4 bits
  transformer.h.5.attn.c_proj: 4 -> 4 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 8-bit...
  transformer.h.0.attn.c_attn.base: 4 -> 8 bits
  transformer.h.0.attn.c_proj: 4 -> 8 bits
  transformer.h.0.mlp.c_fc: 4 -> 8 bits
  transformer.h.0.mlp.c_proj: 4 -> 8 bits
  transformer.h.1.attn.c_attn.base: 4 -> 8 bits
  transformer.h.1.attn.c_proj: 4 -> 8 bits
  transformer.h.1.mlp.c_fc: 4 -> 8 bits
  transformer.h.1.mlp.c_proj: 4 -> 8 bits
  transformer.h.2.attn.c_attn: 4 -> 8 bits
  transformer.h.2.attn.c_proj: 4 -> 8 bits
  transformer.h.2.mlp.c_fc: 4 -> 8 bits
  transformer.h.2.mlp.c_proj: 4 -> 8 bits
  transformer.h.3.attn.c_attn: 4 -> 8 bits
  transformer.h.3.attn.c_proj: 4 -> 8 bits
  transformer.h.3.mlp.c_fc: 4 -> 8 bits
  transformer.h.3.mlp.c_proj: 4 -> 8 bits
  transformer.h.4.attn.c_attn: 4 -> 8 bits
  transformer.h.4.attn.c_proj: 4 -> 8 bits
  transformer.h.4.mlp.c_fc: 4 -> 8 bits
  transformer.h.4.mlp.c_proj: 4 -> 8 bits
  transformer.h.5.attn.c_attn: 4 -> 8 bits
  transformer.h.5.attn.c_proj: 4 -> 8 bits
  transformer.h.5.mlp.c_fc: 4 -

                                                 

🔧 Requantizing model to 4-bit...
  transformer.h.0.attn.c_attn.base: 8 -> 4 bits
  transformer.h.0.attn.c_proj: 8 -> 4 bits
  transformer.h.0.mlp.c_fc: 8 -> 4 bits
  transformer.h.0.mlp.c_proj: 8 -> 4 bits
  transformer.h.1.attn.c_attn.base: 8 -> 4 bits
  transformer.h.1.attn.c_proj: 8 -> 4 bits
  transformer.h.1.mlp.c_fc: 8 -> 4 bits
  transformer.h.1.mlp.c_proj: 8 -> 4 bits
  transformer.h.2.attn.c_attn: 8 -> 4 bits
  transformer.h.2.attn.c_proj: 8 -> 4 bits
  transformer.h.2.mlp.c_fc: 8 -> 4 bits
  transformer.h.2.mlp.c_proj: 8 -> 4 bits
  transformer.h.3.attn.c_attn: 8 -> 4 bits
  transformer.h.3.attn.c_proj: 8 -> 4 bits
  transformer.h.3.mlp.c_fc: 8 -> 4 bits
  transformer.h.3.mlp.c_proj: 8 -> 4 bits
  transformer.h.4.attn.c_attn: 8 -> 4 bits
  transformer.h.4.attn.c_proj: 8 -> 4 bits
  transformer.h.4.mlp.c_fc: 8 -> 4 bits
  transformer.h.4.mlp.c_proj: 8 -> 4 bits
  transformer.h.5.attn.c_attn: 8 -> 4 bits
  transformer.h.5.attn.c_proj: 8 -> 4 bits
  transformer.h.5.mlp.c_fc: 8 -

                                                 

            config  EM       F1  default_bits
           C1_all8 0.0 7.120680             8
            config 0.0 7.120680             8
  C10_mixed_budget 0.0 5.984884             4
   C4_back8_front4 0.0 5.849456             4
       C5_sandwich 0.0 5.746021             4
C8_mlpfc4_mlpproj8 0.0 5.479667             4
   C3_front8_back4 0.0 5.426505             4
     C6_qkv8_proj4 0.0 4.968294             4
           C2_all4 0.0 4.829026             4
 C9_layernorm_fp32 0.0 4.829026             4
       config_4bit 0.0 4.829026             4
     C7_qkv4_proj8 0.0 4.296233             4


