In [None]:
import torch
import torch.nn as nn
from model.modeling_llada import LLaDAModelLM
from transformers import AutoTokenizer
from quantization_calibration_dataset import LLaDACalibrationDataset
from torch.utils.data import DataLoader
import argparse
from duquant_utils import create_quant_args

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
MODEL_PATH = "GSAI-ML/LLaDA-8B-Instruct"
device = "cuda"

model = LLaDAModelLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 707.20it/s]


In [20]:
user_args = {
    "nsamples": 128,
    "seqlen": 2048,
    "wbits": 8,
    "abits": 8,
    "alpha": 0.5, # for smoothquant
    "act_group_size": None,
    "smooth": True,
    "quant_method": "duquant",
    "symmetric": True,
    "group_size": None,
    "swc": 0.8,
    "lac": 0.9,
    "lwc": False,
    "block_size": 128,
    "max_rotation_step": 256,
    "permutation_times": 1,
    "batch_size": 1
}

# args = argparse.Namespace(**user_args)
args = create_quant_args(user_args)


In [21]:
act_scales = torch.load("act_scales/LLaDA-8B-Instruct.pt")        
dataset = LLaDACalibrationDataset(
    tokenizer=tokenizer,
    seq_len=args.seqlen,
    samples=args.nsamples,
    block_size=args.block_size,
)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

Building Calibration Buffer with Mask ID: 126336...
Concatenating and tokenizing dataset...
Total tokens in concatenated dataset: 2608998
Calibration Dataset Ready: 128 tensors.


In [22]:
import copy
import gc
import torch
from torch import nn
from duquant_utils import set_init_duquant_params_state, set_quant_state, smooth_and_let_inplace
from model.quantize.int_linear import QuantLinear
from model.int_llada_layer import LLaDaQuantLayer

CLIPMIN = 1e-5

def duquant(model: nn.Module, act_scales: dict, dataloader, args):
    layers = model.model.transformer.blocks
    use_cache = model.config.use_cache
    model.config.use_cache = False
    dtype = torch.bfloat16
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    seqlen = args.seqlen
    pairs = {
        "q_proj":"qkv",
        "attn_out":"out",
        "up_proj":"fc1",
        "ff_out":"down",
    }

    inps = torch.zeros(
        (args.nsamples, seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            if len(inp.shape) == 3:
                 inps[cache["i"]] = inp[0] 
            else:
                 inps[cache["i"]] = inp
            cache["i"] += 1
            return self.module(inp, **kwargs)
    
    layers[0] = Catcher(layers[0])

    input_ids = []

    with torch.no_grad():
        for batch in dataloader:
            if cache["i"] >= args.nsamples:
                break
            try:
                input_ids.append(batch['input_ids'])
                model(batch['input_ids'].to(dev))

            except ValueError:
                pass
    
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()

    print(layers[0])
    duquant_parameters = {}

    torch.cuda.empty_cache()
    quant_inps = inps
    rotate_inps = copy.copy(inps).mean(dim=0)

    fp_inps = copy.deepcopy(inps)
    
    for i in range(len(layers)):
        print("Starting Layer, " + str(i))
        args.q_quant_params = copy.copy(args.act_quant_params)
        args.k_quant_params = copy.copy(args.act_quant_params)
        layer = layers[i]
        qlayer = LLaDaQuantLayer(layer, args)
        qlayer.set_quant_state(weight_quant=False, act_quant=True)

        for name, module in layer.named_modules():
            if isinstance(module, nn.Linear):
                weight_quant = QuantLinear(module, weight_quant_params=copy.copy(args.weight_quant_params), act_quant_params=copy.copy(args.act_quant_params))
                setattr(qlayer, name, weight_quant)

        qlayer.load_state_dict(layer.state_dict())
        qlayer.to(dev)

        set_init_duquant_params_state(qlayer, True)
        set_quant_state(qlayer, weight_quant=False, act_quant=True)

        qlayer.register_parameter("qkt_smooth_scale",torch.nn.Parameter(torch.ones(qlayer.q_proj.out_features,device=dev, dtype=dtype), requires_grad=False))
        for name, module in qlayer.named_modules():
            if isinstance(module, QuantLinear):
                for key in pairs.keys():
                    if key in name:
                        act = act_scales[f"model.transformer.blocks.{i}.{key}"].to(device=dev, dtype=dtype).clamp(min=CLIPMIN)
                        weight = module.weight.abs().max(dim=0)[0].clamp(min=CLIPMIN)
                        scale = (act.pow(args.alpha)/weight.to(act.device).pow(1-args.alpha)).clamp(min=CLIPMIN)

                        qlayer.register_parameter(f"{pairs[key]}_smooth_scale",torch.nn.Parameter(scale, requires_grad=False))

        qlayer.to(dtype=torch.bfloat16)

        try:
            with torch.no_grad():
                qlayer.qkt_smooth_scale.clamp_(min=0.5)
        except:
            pass
        smooth_and_let_inplace(qlayer, args)

        # perform duquant process
        set_init_duquant_params_state(qlayer, False)
        set_quant_state(qlayer, weight_quant=True, act_quant=True)
        with torch.no_grad():
            with torch.amp.autocast(device_type=dev):
                rotate_inps = qlayer(rotate_inps.unsqueeze(0))[0][0]
            qlayer.register_duquant_params()
            set_init_duquant_params_state(qlayer, True)

        qlayer.to(dtype=torch.bfloat16)
        with torch.no_grad():
            for name, module in qlayer.named_modules():
                if isinstance(module, QuantLinear):
                    module.weight = module.weight_quantizer(module.weight, return_no_quant=True)

        set_quant_state(qlayer, weight_quant=False, act_quant=True)
        layers[i] = qlayer.to("cpu")
        # i dont think this is necessary for loading
        # duquant_parameters[i] = duquant_state_dict(qlayer)

        del layer
        torch.cuda.empty_cache()

    model.model.transformer.embed_tokens = model.model.transformer.wte.to('cpu')
    del inps
    del quant_inps
    del fp_inps
    del rotate_inps
    
    torch.cuda.empty_cache()
    gc.collect()                    
    model.config.use_cache = use_cache
    
    return model

In [23]:
model = duquant(model, act_scales, dataloader, args)

LLaDALlamaBlock(
  (dropout): Dropout(p=0.0, inplace=False)
  (act): SiLU()
  (attn_out): Linear(in_features=4096, out_features=4096, bias=False)
  (ff_out): Linear(in_features=12288, out_features=4096, bias=False)
  (rotary_emb): RotaryEmbedding()
  (attn_norm): RMSLayerNorm()
  (ff_norm): RMSLayerNorm()
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (ff_proj): Linear(in_features=4096, out_features=12288, bias=False)
  (up_proj): Linear(in_features=4096, out_features=12288, bias=False)
)
Starting Layer, 0
Starting Layer, 1
Starting Layer, 2
Starting Layer, 3
Starting Layer, 4
Starting Layer, 5
Starting Layer, 6
Starting Layer, 7
Starting Layer, 8
Starting Layer, 9
Starting Layer, 10
Starting Layer, 11
Starting Layer, 12
Starting Layer, 13
Starting Layer, 14
Starting Layer, 15
Starting Layer, 16
Starting Layer, 17
Starting Layer,

In [26]:
model.to("cuda")
model.eval()

LLaDAModelLM(
  (model): LLaDAModel(
    (transformer): ModuleDict(
      (wte): Embedding(126464, 4096)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): RMSLayerNorm()
      (blocks): ModuleList(
        (0-31): 32 x LLaDaQuantLayer(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SiLU()
          (attn_out): QuantLinear(
            (weight_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
            (act_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
          )
          (ff_out): QuantLinear(
            (weight_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
            (act_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
          )
          (rotary_emb): RotaryEmbedding()
          (attn_norm): RMSLayerNorm()
          (ff_norm): RMSLayerNorm()
          (q_proj): QuantLinear(
            (weight_qu

In [27]:
from generate import generate

user_input = input("Enter your question: ")

m = [{"role": "user", "content": user_input}]
user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
input_ids = tokenizer(user_input)['input_ids']
input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

out, nfe = generate(model, input_ids, steps=128, gen_length=128, block_length=args.block_size, temperature=0., remasking='low_confidence', threshold=0.9)
answer = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
print(f"Bot's reply: {answer}")

Bot's reply: Hello! How can I assist you today?


In [30]:
original_model = LLaDAModelLM.from_pretrained(MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 213.11it/s]


In [33]:
quantized_state_dict = model.state_dict()
original_state_dict = original_model.state_dict()

for key in quantized_state_dict.keys():
    quantized_state_dict[key] = quantized_state_dict[key].to("cpu")

for key in original_state_dict.keys():
    original_state_dict[key] = original_state_dict[key].to("cpu")

common_keys = set(quantized_state_dict.keys()) & set(original_state_dict.keys())
for key in common_keys:
    quantized_weight = quantized_state_dict[key]
    original_weight = original_state_dict[key]
    if torch.equal(quantized_weight, original_weight):
        print(f"Equality found in {key}")
        print(f"Quantized weight shape: {quantized_weight.shape}")
        print(f"Original weight shape: {original_weight.shape}")

Equality found in model.transformer.ff_out.weight
Quantized weight shape: torch.Size([126464, 4096])
Original weight shape: torch.Size([126464, 4096])
Equality found in model.transformer.wte.weight
Quantized weight shape: torch.Size([126464, 4096])
Original weight shape: torch.Size([126464, 4096])
Equality found in model.transformer.ln_f.weight
Quantized weight shape: torch.Size([4096])
Original weight shape: torch.Size([4096])


In [13]:
import torch
import torch.nn.functional as F

model_path_1 = "models/quantized_model.pth"
model_path_2 = "models/my_quantized_4_4.pth"

# Load state dicts (adjust if using safetensors or specific quant loaders)
sd1 = torch.load(model_path_1, map_location="cpu")
sd2 = torch.load(model_path_2, map_location="cpu")

# Ensure keys match
keys1, keys2 = set(sd1.keys()), set(sd2.keys())
common_keys = keys1.intersection(keys2)
print(keys1 ^ keys2)

if len(common_keys) == 0:
    print("Error: No matching layers found. Check naming conventions.")

results = []

for key in common_keys:
    tensor1 = sd1[key].float().flatten()
    tensor2 = sd2[key].float().flatten()
    
    # Check for shape mismatches (e.g., if one fused Q/K/V and the other didn't)
    if tensor1.shape != tensor2.shape:
        print(f"Shape mismatch at {key}: {tensor1.shape} vs {tensor2.shape}")
        continue

    # Calculate metrics
    mse = F.mse_loss(tensor1, tensor2).item()
    cos_sim = F.cosine_similarity(tensor1.unsqueeze(0), tensor2.unsqueeze(0)).item()
    max_diff = torch.max(torch.abs(tensor1 - tensor2)).item()
    
    results.append({
        "layer": key,
        "mse": mse,
        "cos_sim": cos_sim,
        "max_diff": max_diff
    })

# Sort by MSE to find the most divergent layers
results.sort(key=lambda x: x['mse'], reverse=True)

print(f"{'Layer Name':<45} | {'MSE':<10} | {'Cos Sim':<10} | {'Max Diff':<10}")
print("-" * 85)

# Print the top 15 most divergent layers
for res in results:
    print(f"{res['layer']} | {res['mse']:<10.4f} | {res['cos_sim']:<10.4f} | {res['max_diff']:<10.4f}")

# Example usage:


{'model.transformer.blocks.9.ff_proj.bias', 'model.transformer.blocks.25.ori_layer.k_proj.weight', 'model.transformer.blocks.8.up_proj.bias', 'model.transformer.blocks.5.attn_out.bias', 'model.transformer.blocks.28.ff_norm.bias', 'model.transformer.blocks.17.ori_layer.attn_norm.weight', 'model.transformer.blocks.29.ori_layer.ff_out.weight', 'model.transformer.blocks.28.down_smooth_shift', 'model.transformer.blocks.8.attn_out.bias', 'model.transformer.blocks.2.q_proj.bias', 'model.transformer.blocks.0.fc1_smooth_shift', 'model.transformer.blocks.9.ori_layer.v_proj.weight', 'model.transformer.blocks.22.q_proj.bias', 'model.transformer.blocks.1.qkv_smooth_shift', 'model.transformer.blocks.3.ori_layer.q_proj.weight', 'model.transformer.blocks.20.ori_layer.q_proj.weight', 'model.transformer.blocks.19.ori_layer.attn_norm.weight', 'model.transformer.blocks.28.fc1_smooth_shift', 'model.transformer.blocks.19.ff_norm.bias', 'model.transformer.blocks.10.ori_layer.ff_out.weight', 'model.transforme

In [9]:
sd1["model.transformer.blocks.1.ff_out.weight_quantizer.permutation_list"]

tensor([[11499,  7132,  3515,  ...,   501,  8872,  3334]])

In [10]:
sd2["model.transformer.blocks.1.ff_out.weight_quantizer.permutation_list"]

tensor([[11482,  1776,  4833,  ...,  1676, 11236,  2820]])

In [18]:
sd1.keys() - sd2.keys()

{'model.transformer.blocks.0.attn_norm.bias',
 'model.transformer.blocks.0.attn_out.bias',
 'model.transformer.blocks.0.down_smooth_shift',
 'model.transformer.blocks.0.fc1_smooth_shift',
 'model.transformer.blocks.0.ff_norm.bias',
 'model.transformer.blocks.0.ff_out.bias',
 'model.transformer.blocks.0.ff_proj.bias',
 'model.transformer.blocks.0.k_proj.bias',
 'model.transformer.blocks.0.ori_layer.attn_norm.weight',
 'model.transformer.blocks.0.ori_layer.attn_out.weight',
 'model.transformer.blocks.0.ori_layer.ff_norm.weight',
 'model.transformer.blocks.0.ori_layer.ff_out.weight',
 'model.transformer.blocks.0.ori_layer.ff_proj.weight',
 'model.transformer.blocks.0.ori_layer.k_proj.weight',
 'model.transformer.blocks.0.ori_layer.q_proj.weight',
 'model.transformer.blocks.0.ori_layer.up_proj.weight',
 'model.transformer.blocks.0.ori_layer.v_proj.weight',
 'model.transformer.blocks.0.out_smooth_shift',
 'model.transformer.blocks.0.q_proj.bias',
 'model.transformer.blocks.0.qkv_smooth_shif