In [1]:
import torch
import torch.nn as nn

from duquant_utils import *
from model.modeling_llada import LLaDAModelLM
from transformers import AutoModelForCausalLM, AutoTokenizer
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_PATH = "GSAI-ML/LLaDA-8B-Instruct"
QUANT_ARGS_PATH = "model/quantize/quant_args.json"
QUANTIZED_WEIGHTS_PATH = "models/quantized_model.pth"

In [3]:
model = LLaDAModelLM.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="cpu", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
quant_config = json.load(open(QUANT_ARGS_PATH))
model_weights = torch.load(QUANTIZED_WEIGHTS_PATH, map_location="cpu")

quant_args = create_quant_args(quant_config)

Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 687.84it/s]


In [4]:
original_model = LLaDAModelLM.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="cuda", dtype=torch.bfloat16)

original_model.eval()

2.9.1+cu130


    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    
Loading checkpoint shards: 100%|██████████| 6/6 [00:33<00:00,  5.58s/it]


LLaDAModelLM(
  (model): LLaDAModel(
    (transformer): ModuleDict(
      (wte): Embedding(126464, 4096)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): RMSLayerNorm()
      (blocks): ModuleList(
        (0-31): 32 x LLaDALlamaBlock(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SiLU()
          (attn_out): Linear(in_features=4096, out_features=4096, bias=False)
          (ff_out): Linear(in_features=12288, out_features=4096, bias=False)
          (rotary_emb): RotaryEmbedding()
          (attn_norm): RMSLayerNorm()
          (ff_norm): RMSLayerNorm()
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (ff_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (up_proj): Linear(in_features=4096, out_features=12288, bias=False)
        )
      )
    

In [5]:
#loading duquant model:

replace_llada_blocks(model, quant_args, device="cpu")

replace_linear_layers(model, quant_args, model_weights)

print("Loading Quantized Model...")
missing_keys, unexpected_keys = model.load_state_dict(model_weights, strict=False)
model.to("cuda")
model.eval()


Replacing block model.transformer.blocks.0
Replacing block model.transformer.blocks.1
Replacing block model.transformer.blocks.2
Replacing block model.transformer.blocks.3
Replacing block model.transformer.blocks.4
Replacing block model.transformer.blocks.5
Replacing block model.transformer.blocks.6
Replacing block model.transformer.blocks.7
Replacing block model.transformer.blocks.8
Replacing block model.transformer.blocks.9
Replacing block model.transformer.blocks.10
Replacing block model.transformer.blocks.11
Replacing block model.transformer.blocks.12
Replacing block model.transformer.blocks.13
Replacing block model.transformer.blocks.14
Replacing block model.transformer.blocks.15
Replacing block model.transformer.blocks.16
Replacing block model.transformer.blocks.17
Replacing block model.transformer.blocks.18
Replacing block model.transformer.blocks.19
Replacing block model.transformer.blocks.20
Replacing block model.transformer.blocks.21
Replacing block model.transformer.blocks.2

LLaDAModelLM(
  (model): LLaDAModel(
    (transformer): ModuleDict(
      (wte): Embedding(126464, 4096)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): RMSLayerNorm()
      (blocks): ModuleList(
        (0-31): 32 x LLaDaQuantLayer(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SiLU()
          (attn_out): QuantLinear(
            (weight_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
            (act_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
          )
          (ff_out): QuantLinear(
            (weight_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
            (act_quantizer): UniformAffineQuantizer(
              (sigmoid): Sigmoid()
            )
          )
          (rotary_emb): RotaryEmbedding()
          (attn_norm): RMSLayerNorm()
          (ff_norm): RMSLayerNorm()
          (q_proj): QuantLinear(
            (weight_qu

In [6]:
from model.quantize.int_linear import QuantLinear

def create_activation_hook(storage, name):
    def hook(module, input, output):
        x = input[0] if isinstance(input, tuple) else input
        storage[name] = x.detach().cpu().clone()
    return hook

def register_activation_hooks(model, activations):
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) or isinstance(module, QuantLinear):
            hook = module.register_forward_hook(create_activation_hook(activations, name))
            hooks.append(hook)
    return hooks

def remove_hooks(hooks):
    for hook in hooks:
        hook.remove()
    print(f"Removed {len(hooks)} hooks")

In [7]:
# Create test input using a sample prompt
# For LLaDA, we also apply some masking to simulate the diffusion process

test_prompt = """The quick brown fox jumps over the lazy dog. This is a test sentence 
to analyze how SmoothQuant affects the activation distributions in the LLaDA model.
Machine learning models often have outlier activations that make quantization difficult.
SmoothQuant addresses this by migrating quantization difficulty from activations to weights."""

# Tokenize the input
encoded = tokenizer(
    test_prompt,
    return_tensors='pt',
    padding='max_length',
    max_length=1024,
    truncation=True,
)
test_input_ids = encoded.input_ids.to("cuda")

# Get the mask token id for LLaDA
if hasattr(tokenizer, "mask_token_id") and tokenizer.mask_token_id is not None:
    mask_token_id = tokenizer.mask_token_id
else:
    mask_token_id = 126336  # Default for LLaDA

# Apply partial masking to simulate LLaDA's input (50% masking)
mask_ratio = 0.5
num_tokens = test_input_ids.shape[1]
num_mask = int(num_tokens * mask_ratio)
mask_indices = torch.randperm(num_tokens)[:num_mask]

masked_input_ids = test_input_ids.clone()
masked_input_ids[0, mask_indices] = mask_token_id

print(f"Test input shape: {masked_input_ids.shape}")
print(f"Number of masked tokens: {num_mask} / {num_tokens} ({mask_ratio*100:.0f}%)")

Test input shape: torch.Size([1, 1024])
Number of masked tokens: 512 / 1024 (50%)


In [8]:
original_activations = {}

original_hooks = register_activation_hooks(original_model, original_activations)

with torch.no_grad():
    _ = original_model(input_ids=masked_input_ids)

remove_hooks(original_hooks)

Removed 225 hooks


In [9]:
original_activations.keys()

dict_keys(['model.transformer.blocks.0.q_proj', 'model.transformer.blocks.0.k_proj', 'model.transformer.blocks.0.v_proj', 'model.transformer.blocks.0.attn_out', 'model.transformer.blocks.0.ff_proj', 'model.transformer.blocks.0.up_proj', 'model.transformer.blocks.0.ff_out', 'model.transformer.blocks.1.q_proj', 'model.transformer.blocks.1.k_proj', 'model.transformer.blocks.1.v_proj', 'model.transformer.blocks.1.attn_out', 'model.transformer.blocks.1.ff_proj', 'model.transformer.blocks.1.up_proj', 'model.transformer.blocks.1.ff_out', 'model.transformer.blocks.2.q_proj', 'model.transformer.blocks.2.k_proj', 'model.transformer.blocks.2.v_proj', 'model.transformer.blocks.2.attn_out', 'model.transformer.blocks.2.ff_proj', 'model.transformer.blocks.2.up_proj', 'model.transformer.blocks.2.ff_out', 'model.transformer.blocks.3.q_proj', 'model.transformer.blocks.3.k_proj', 'model.transformer.blocks.3.v_proj', 'model.transformer.blocks.3.attn_out', 'model.transformer.blocks.3.ff_proj', 'model.trans

In [10]:
import re
joint_layer_activations = {}

for key in original_activations:
    if not re.search(r'\d+', key) or key.find('ff_out') != -1:
        continue
    layer = key.split('.')[3]
    if layer not in joint_layer_activations:
        joint_layer_activations[layer] = original_activations[key]
    else:
        joint_layer_activations[layer] = torch.cat([joint_layer_activations[layer], original_activations[key]], dim=1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

