# Quantization + Rate-Distortion: Trivial Methods

**Objective:** Explore simple weight-only quantization methods for Llama models and measure empirical rate-distortion curves vs the Shannon bound.

**Shannon Bound (Gaussian source, MSE distortion):**  
$R(D) = \tfrac{1}{2}\log_2\!\left(\sigma^2 / D\right)$ for $0 < D \le \sigma^2$, and $R(D) = 0$ for $D > \sigma^2$.

This notebook covers:
- Weight extraction (MLP + attention + embeddings)
- Simple quantization sweeps (per-tensor vs per-channel)
- Rate-distortion comparisons (uniform, clipped, Lloyd-Max, group)
- Vector quantization (k-means, d = 2, 4, 8)
- Product quantization (multi-codebook variant)

In [None]:
# === AUTHENTICATION (required) ===
import os

from huggingface_hub import login

# Take the access token from the HF_TOKEN environment variable so that no
# secret is ever stored in the notebook source or its saved outputs.
# When the variable is unset, token=None makes login() ask for the token
# interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# === Install dependencies ===
# Installs the required packages; comment these lines out if they are already installed

%pip install -q torch torchvision
%pip install -q numpy scipy scikit-learn matplotlib pandas transformers accelerate

In [None]:
!nvidia-smi

# Initial Quantization Experiments
These first experiments are a warm-up: they apply the simplest (trivial) quantization schemes to transformer weights to build intuition before the rate-distortion analysis.

In [None]:
# Core imports
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM

try:
    from sklearn.cluster import MiniBatchKMeans
    SKLEARN_OK = True
except Exception:
    SKLEARN_OK = False

try:
    from scipy.cluster.vq import kmeans, vq
    SCIPY_VQ_OK = True
except Exception:
    SCIPY_VQ_OK = False

In [None]:
# Load model (Llama 3.2 1B by default)
# The later cells select weights by Llama-style parameter names
# ("model.layers.<n>.mlp.down_proj.weight", ...), so a different checkpoint
# only works if it follows the same naming.
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Load in float32 so the reference weights are full precision.
    torch_dtype=torch.float32,
    # Device placement is left to the library (presumably handled by
    # accelerate -- TODO confirm).
    device_map="auto",
)
# No training happens in this notebook; put the model in eval mode.
model.eval()

print(f"Loaded: {model_name}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Extract weight groups (MLP, attention, embeddings)

def extract_weight_groups(model):
    """
    Collect selected weight matrices of a model as float32 NumPy arrays.

    Only parameters whose name ends in ".weight" and contains "model.layers."
    are considered. Of those, the MLP down-projection and the attention
    q/k/v/o projections are kept, keyed by layer index.

    Parameters
    ----------
    model : object
        Must expose ``named_parameters()`` yielding ``(name, tensor)`` pairs.
        ``get_input_embeddings()`` and ``lm_head`` are used when available.

    Returns
    -------
    mlp : dict[int, np.ndarray]
        Layer index -> ``mlp.down_proj`` weight.
    attn : dict[str, dict[int, np.ndarray]]
        ``"q"``/``"k"``/``"v"``/``"o"`` -> layer index -> projection weight.
    embeddings : dict[str, np.ndarray]
        ``"token"`` and/or ``"lm_head"`` when they could be read.
    """
    import warnings

    def _to_numpy(tensor):
        return tensor.detach().cpu().float().numpy()

    attn_markers = {
        "self_attn.q_proj.weight": "q",
        "self_attn.k_proj.weight": "k",
        "self_attn.v_proj.weight": "v",
        "self_attn.o_proj.weight": "o",
    }

    mlp = {}
    attn = {"q": {}, "k": {}, "v": {}, "o": {}}

    for name, param in model.named_parameters():
        if not name.endswith(".weight"):
            continue
        if "model.layers." not in name:
            continue

        # The layer index is the component right after "layers". Looking it up
        # relative to that component (instead of at a fixed position) keeps
        # the extraction working when the name carries an extra prefix, e.g.
        # when the model is wrapped by another module.
        parts = name.split(".")
        try:
            layer_num = int(parts[parts.index("layers") + 1])
        except (ValueError, IndexError):
            continue

        # Convert only the tensors that are actually kept; every other layer
        # parameter is skipped without a device-to-host copy.
        if "mlp.down_proj.weight" in name:
            mlp[layer_num] = _to_numpy(param)
            continue
        for marker, key in attn_markers.items():
            if marker in name:
                attn[key][layer_num] = _to_numpy(param)
                break

    embeddings = {}
    # Embedding extraction stays best-effort (a model without these tensors is
    # still usable for the per-layer analysis), but a failure is now reported
    # instead of being dropped silently.
    try:
        embeddings["token"] = _to_numpy(model.get_input_embeddings().weight)
    except Exception as exc:
        warnings.warn(f"Skipping token embeddings: {exc!r}")

    if hasattr(model, "lm_head") and model.lm_head is not None:
        try:
            embeddings["lm_head"] = _to_numpy(model.lm_head.weight)
        except Exception as exc:
            warnings.warn(f"Skipping lm_head weights: {exc!r}")

    return mlp, attn, embeddings

# Run the extraction once and report how many tensors landed in each group.
mlp_weights, attn_weights, embedding_weights = extract_weight_groups(model)

print(f"MLP down-proj layers: {len(mlp_weights)}")
print("Attention proj layers:")
for k in ("q", "k", "v", "o"):
    print(f"  {k}: {len(attn_weights[k])}")
print("Embeddings:", list(embedding_weights))

In [None]:
# Quantization utilities (per-tensor vs per-channel)

def symmetric_quantize(x, bits, per_channel=False, axis=1):
    """
    Quantize ``x`` onto a symmetric signed-integer grid and dequantize it.

    The integer codes are limited to ``[-qmax, qmax]`` with
    ``qmax = 2**(bits - 1) - 1``. With ``per_channel=False`` a single scale is
    derived from the largest magnitude in the tensor; with ``per_channel=True``
    one scale is derived per slice by reducing over ``axis`` (``axis=1`` gives
    one scale per row of an ``[out, in]`` matrix).

    Returns ``(dequantized float32 array, scale)``, where ``scale`` is a Python
    float in the per-tensor case and a keep-dims array in the per-channel case.
    """
    values = x.astype(np.float32)
    top_code = (2 ** (bits - 1)) - 1

    if not per_channel:
        peak = float(np.max(np.abs(values)))
        # An all-zero tensor gets a dummy scale so the division stays defined.
        scale = peak / top_code if peak > 0 else 1.0
    else:
        # The small epsilon keeps all-zero channels from producing a zero scale.
        peak = np.max(np.abs(values), axis=axis, keepdims=True) + 1e-12
        scale = peak / top_code

    codes = np.clip(np.round(values / scale), -top_code, top_code)
    dequantized = codes * scale
    return dequantized.astype(np.float32), scale


def quant_metrics(x, xq):
    """
    Summarize the error between a tensor ``x`` and its quantized version ``xq``.

    Returns a dict with the mean squared error, the signal-to-quantization-noise
    ratio in dB (``inf`` unless both the error and the signal power are
    positive), the largest magnitude in ``x``, the largest absolute error, and
    that error relative to the largest magnitude (0.0 for an all-zero ``x``).
    """
    reference = x.astype(np.float32)
    error = reference - xq.astype(np.float32)

    mse = float(np.mean(error ** 2))
    signal_power = float(np.mean(reference ** 2))
    max_abs = float(np.max(np.abs(reference)))
    max_err = float(np.max(np.abs(error)))

    if mse > 0 and signal_power > 0:
        sqnr_db = 10 * math.log10(signal_power / mse)
    else:
        sqnr_db = float("inf")

    if max_abs > 0:
        rel_max_err = max_err / max_abs
    else:
        rel_max_err = 0.0

    return {
        "mse": mse,
        "sqnr_db": sqnr_db,
        "max_abs": max_abs,
        "max_err": max_err,
        "rel_max_err": rel_max_err,
    }

In [None]:
# Quick bit sweep (per-tensor vs per-channel)
# Quantizes one MLP down-projection matrix at several bit-widths with both
# scaling schemes, prints a comparison table and plots MSE against bits.
layer = 8
W = mlp_weights[layer]

bits_list = [2, 3, 4, 5, 6, 8]

# One metrics dict per bit-width, stored in the same order as bits_list.
results = {"per_tensor": [], "per_channel": []}

for bits in bits_list:
    dq_t, _ = symmetric_quantize(W, bits, per_channel=False)
    dq_c, _ = symmetric_quantize(W, bits, per_channel=True, axis=1)

    results["per_tensor"].append(quant_metrics(W, dq_t))
    results["per_channel"].append(quant_metrics(W, dq_c))

# Table columns: (t) = per-tensor scaling, (c) = per-channel scaling.
print(f"Layer {layer} quantization summary")
print(f"{'Bits':>6} {'MSE(t)':>12} {'MSE(c)':>12} {'SQNR(t)':>10} {'SQNR(c)':>10}")
print("-" * 60)
for i, bits in enumerate(bits_list):
    mt = results["per_tensor"][i]
    mc = results["per_channel"][i]
    print(f"{bits:>6} {mt['mse']:>12.2e} {mc['mse']:>12.2e} {mt['sqnr_db']:>10.2f} {mc['sqnr_db']:>10.2f}")

# MSE is drawn on a log axis.
plt.figure(figsize=(7, 4))
plt.plot(bits_list, [r['mse'] for r in results["per_tensor"]], marker='o', label='per-tensor')
plt.plot(bits_list, [r['mse'] for r in results["per_channel"]], marker='o', label='per-channel')
plt.yscale('log')
plt.xlabel('Bits')
plt.ylabel('MSE (log scale)')
plt.title(f'Layer {layer}: Quantization MSE')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Layerwise sensitivity at a fixed bit-width (per-channel)
# Quantizes every extracted MLP down-projection at `bits` bits, lists the
# layers with the largest error and plots the error against layer index.
bits = 4

# One (layer, mse, sqnr_db, rel_max_err) row per layer, in layer order.
layer_rows = []
for layer_num, W in sorted(mlp_weights.items()):
    dq, _ = symmetric_quantize(W, bits, per_channel=True, axis=1)
    m = quant_metrics(W, dq)
    layer_rows.append((layer_num, m["mse"], m["sqnr_db"], m["rel_max_err"]))

# Rank a *copy* by MSE for the table. layer_rows itself must stay in layer
# order: the line plot below connects consecutive points, so feeding it the
# MSE-ranked rows would draw a zig-zag instead of a curve over depth.
worst_rows = sorted(layer_rows, key=lambda t: t[1], reverse=True)

print(f"Top 8 worst MLP layers at {bits}-bit per-channel")
print(f"{'Layer':>6} {'MSE':>12} {'SQNR(dB)':>10} {'RelMaxErr':>10}")
print("-" * 46)
for layer_num, mse, sqnr_db, rel_max_err in worst_rows[:8]:
    print(f"{layer_num:>6} {mse:>12.2e} {sqnr_db:>10.2f} {rel_max_err:>10.3f}")

layers = [r[0] for r in layer_rows]
mse_vals = [r[1] for r in layer_rows]

plt.figure(figsize=(10, 3))
plt.plot(layers, mse_vals, marker='o', linewidth=1)
plt.yscale('log')
plt.xlabel('Layer')
plt.ylabel('MSE (log scale)')
plt.title(f'MLP down-proj quantization error ({bits}-bit, per-channel)')
plt.tight_layout()
plt.show()

In [None]:
# Shannon bound and gap utilities

def shannon_distortion(sigma_sq, rate_bits):
    """
    Return sigma_sq / 4**rate_bits, i.e. the distortion D obtained by solving
    R = 1/2 * log2(sigma_sq / D) for D at rate ``rate_bits``.
    """
    shrink_factor = 4 ** rate_bits
    return sigma_sq / shrink_factor


def gap_bits(mse, d_shannon):
    """
    Express the ratio mse / d_shannon as a rate gap: 1/2 * log2 of the ratio.

    Returns 0.0 when either argument is zero or negative, where the logarithm
    would not be defined.
    """
    undefined = mse <= 0 or d_shannon <= 0
    if undefined:
        return 0.0
    ratio = mse / d_shannon
    return 0.5 * math.log2(ratio)


def interpolate_rate_for_distortion(rates, mses, target_d):
    """
    Estimate the rate at which an empirical curve reaches ``target_d``.

    The (rate, mse) points are sorted by rate; the first pair of neighbouring
    points whose distortions bracket ``target_d`` is interpolated linearly in
    log-distortion.

    rates: list of bits/weight
    mses: list of distortions, one per rate
    target_d: target distortion

    Returns the interpolated rate as a float, or None when ``target_d`` is not
    positive or no neighbouring pair brackets it (no extrapolation is done).
    """
    rate_arr = np.asarray(rates, dtype=np.float32)
    mse_arr = np.asarray(mses, dtype=np.float32)

    if target_d <= 0:
        return None

    by_rate = np.argsort(rate_arr)
    rate_arr = rate_arr[by_rate]
    mse_arr = mse_arr[by_rate]

    # The tiny offsets guard against log(0) and against a zero-width segment.
    log_d = np.log(mse_arr + 1e-30)
    log_target = np.log(target_d)

    segments = zip(rate_arr[:-1], rate_arr[1:], log_d[:-1], log_d[1:])
    for r_lo, r_hi, d_lo, d_hi in segments:
        falling = d_lo >= log_target >= d_hi
        rising = d_lo <= log_target <= d_hi
        if falling or rising:
            frac = (log_target - d_lo) / (d_hi - d_lo + 1e-30)
            return float(r_lo + frac * (r_hi - r_lo))

    return None

In [None]:
# Scalar + group quantization methods

def quantize_uniform_asymmetric(weights, bits):
    """
    Uniform quantization over the full [min, max] range of the weights.

    The range is divided into ``2**bits - 1`` equal steps and every weight is
    snapped to the nearest grid point. Returns ``(quantized flat float32
    array, mse)``; a constant tensor is returned unchanged with mse 0.0.
    """
    flat = weights.flatten().astype(np.float32)
    lo = float(flat.min())
    hi = float(flat.max())
    if hi == lo:
        return flat, 0.0

    step = (hi - lo) / (2 ** bits - 1)
    grid_index = np.round((flat - lo) / step)
    quantized = grid_index * step + lo

    error = flat - quantized
    return quantized, float(np.mean(error ** 2))


def quantize_uniform_symmetric(weights, bits):
    """
    Uniform quantization on a zero-centred grid scaled by the largest magnitude.

    The step is ``2 * max_abs / (2**bits - 1)`` and weights are rounded to the
    nearest integer multiple of the step. Codes are limited to
    ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]`` so the result always fits in the
    ``bits``-bit codebook and never exceeds ``max_abs`` in magnitude.

    Parameters
    ----------
    weights : np.ndarray
        Tensor of any shape; it is flattened.
    bits : int
        Bit-width of the quantizer.

    Returns
    -------
    (quantized, mse) : (np.ndarray, float)
        Flat float32 reconstruction and its mean squared error. An all-zero
        tensor is returned unchanged with mse 0.0.
    """
    flat = weights.flatten().astype(np.float32)
    levels = 2 ** bits
    max_abs = float(np.max(np.abs(flat)))
    if max_abs == 0:
        return flat, 0.0

    scale = (2 * max_abs) / (levels - 1)
    max_code = levels // 2 - 1

    # The largest-magnitude weight sits at +/-(levels - 1) / 2 steps, an exact
    # ".5" tie, so without the clip it can round outward to code levels // 2,
    # which is outside the codebook and reconstructs to a value beyond max_abs.
    codes = np.clip(np.round(flat / scale), -max_code, max_code)
    quantized = codes * scale

    mse = float(np.mean((flat - quantized) ** 2))
    return quantized, mse


def quantize_symmetric_clipped(weights, bits, clip_sigma=3.0):
    """
    Symmetric uniform quantization after clipping to +/- clip_sigma * std.

    Weights are first clipped to ``[-clip_sigma * sigma, clip_sigma * sigma]``
    (the clip is centred on zero, not on the mean), then quantized on a
    zero-centred grid with step ``2 * max_abs / (2**bits - 1)``, where
    ``max_abs`` is the largest magnitude after clipping. Codes are limited to
    ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]``.

    Parameters
    ----------
    weights : np.ndarray
        Tensor of any shape; it is flattened.
    bits : int
        Bit-width of the quantizer.
    clip_sigma : float
        Clipping threshold in units of the standard deviation.

    Returns
    -------
    (quantized, mse) : (np.ndarray, float)
        Flat float32 reconstruction and its mean squared error, measured
        against the *unclipped* weights so clipping error is included.
    """
    flat = weights.flatten().astype(np.float32)
    levels = 2 ** bits

    sigma = float(np.std(flat))
    clip_val = clip_sigma * sigma
    flat_clipped = np.clip(flat, -clip_val, clip_val)

    max_abs = float(np.max(np.abs(flat_clipped)))
    if max_abs == 0:
        # Everything was clipped to zero (all-zero or zero-variance input):
        # the reconstruction is all zeros, and the error is reported honestly
        # rather than as 0.0.
        quantized = np.zeros_like(flat)
    else:
        scale = (2 * max_abs) / (levels - 1)
        max_code = levels // 2 - 1
        # Values at the clipping threshold sit on an exact ".5" tie and would
        # otherwise round outward to a code beyond the codebook.
        codes = np.clip(np.round(flat_clipped / scale), -max_code, max_code)
        quantized = codes * scale

    mse = float(np.mean((flat - quantized) ** 2))
    return quantized, mse


def quantize_group(weights, bits, group_size=128):
    """
    Symmetric uniform quantization with one scale per group of weights.

    The flattened weights are split into consecutive groups of ``group_size``
    (the last group is zero-padded; the padding is dropped again before the
    error is measured). Each group is quantized on a zero-centred grid with
    step ``2 * max_abs / (2**bits - 1)`` using that group's own ``max_abs``,
    with codes limited to ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]``.

    Parameters
    ----------
    weights : np.ndarray
        Tensor of any shape; it is flattened.
    bits : int
        Bit-width of the per-weight codes.
    group_size : int
        Number of consecutive weights sharing one scale.

    Returns
    -------
    (quantized, mse, effective_bits) : (np.ndarray, float, float)
        Flat float32 reconstruction, its mean squared error, and the rate
        ``bits + 16 / group_size`` that charges one FP16 scale per group.
    """
    flat = weights.flatten().astype(np.float32)
    n = len(flat)

    pad_size = (group_size - n % group_size) % group_size
    if pad_size > 0:
        flat_padded = np.concatenate([flat, np.zeros(pad_size, dtype=np.float32)])
    else:
        flat_padded = flat

    groups = flat_padded.reshape(-1, group_size)
    levels = 2 ** bits
    max_code = levels // 2 - 1

    # All groups are processed at once instead of in a Python loop.
    max_abs = np.max(np.abs(groups), axis=1, keepdims=True)
    scale = (2 * max_abs) / (levels - 1)
    # All-zero groups get a dummy scale of 1 so they quantize to exactly zero.
    safe_scale = np.where(scale > 0, scale, np.float32(1.0))

    # A group's largest weight sits on an exact ".5" tie; without the clip it
    # can round outward to a code that does not fit in `bits` bits.
    codes = np.clip(np.round(groups / safe_scale), -max_code, max_code)
    quantized = (codes * safe_scale).reshape(-1)[:n]

    mse = float(np.mean((flat - quantized) ** 2))
    effective_bits = bits + 16 / group_size  # FP16 scale per group
    return quantized, mse, effective_bits


def quantize_lloyd_max(weights, bits, max_iter=20):
    """
    Non-uniform scalar quantization with ``2**bits`` levels via Lloyd iteration.

    Levels start at evenly spaced interior percentiles of the data. Each
    iteration places decision boundaries midway between neighbouring levels
    and moves every level to the mean of the weights assigned to it (a level
    with no weights keeps its position). Iteration stops after ``max_iter``
    rounds or once the levels stop moving.

    Returns ``(quantized flat array, mse)``.
    """
    flat = weights.flatten().astype(np.float32)
    levels = 2 ** bits

    # Interior percentiles only: the 0th and 100th are dropped.
    start_percentiles = np.linspace(0, 100, levels + 2)[1:-1]
    centroids = np.percentile(flat, start_percentiles)

    for _ in range(max_iter):
        edges = (centroids[:-1] + centroids[1:]) / 2
        cell = np.digitize(flat, edges)

        updated = []
        for level in range(levels):
            members = flat[cell == level]
            updated.append(np.mean(members) if members.size else centroids[level])
        updated = np.asarray(updated, dtype=np.float32)

        if np.allclose(centroids, updated, rtol=1e-6):
            break
        centroids = updated

    # Final assignment against the converged (or last) set of levels.
    edges = (centroids[:-1] + centroids[1:]) / 2
    cell = np.clip(np.digitize(flat, edges), 0, levels - 1)
    quantized = centroids[cell]

    return quantized, float(np.mean((flat - quantized) ** 2))

In [None]:
# k-means helpers (vector + product quantization)

def _kmeans_fit(vectors, n_centroids, seed=0):
    """
    Fit a k-means codebook with whichever backend is available.

    scikit-learn's MiniBatchKMeans is preferred; SciPy's ``kmeans`` is the
    fallback. ``seed`` is forwarded to both backends so a run is reproducible
    regardless of which one is installed.

    Returns the fitted MiniBatchKMeans object (scikit-learn path) or the
    centroid array (SciPy path); ``_kmeans_assign`` accepts either.

    Raises RuntimeError when neither backend could be imported.
    """
    if SKLEARN_OK:
        km = MiniBatchKMeans(
            n_clusters=n_centroids,
            random_state=seed,
            batch_size=4096,
            n_init=3,
            max_iter=50,
        )
        km.fit(vectors)
        return km

    if SCIPY_VQ_OK:
        # Without the seed this path gave different codebooks on every run,
        # even though the caller asked for a fixed seed.
        centroids, _ = kmeans(vectors, n_centroids, iter=20, seed=seed)
        return centroids

    raise RuntimeError("No k-means backend found. Install scikit-learn or scipy.")


def _kmeans_assign(vectors, model_or_centroids):
    """
    Assign each vector to its nearest centroid.

    Accepts what ``_kmeans_fit`` returns: a fitted estimator exposing
    ``predict`` (scikit-learn path) or a plain centroid array (SciPy path).
    Returns ``(assignments, centroids)``.

    Raises RuntimeError when no usable backend is available.
    """
    is_estimator = SKLEARN_OK and hasattr(model_or_centroids, "predict")
    if is_estimator:
        labels = model_or_centroids.predict(vectors)
        return labels, model_or_centroids.cluster_centers_

    if not SCIPY_VQ_OK:
        raise RuntimeError("No k-means backend found. Install scikit-learn or scipy.")

    labels, _ = vq(vectors, model_or_centroids)
    return labels, model_or_centroids


def quantize_kmeans_vq(weights, bits, dim=2, sample_vectors=10000, eval_vectors=10000, seed=0):
    """
    Vector quantization with a single k-means codebook of ``2**bits`` entries.

    The flattened weights (zero-padded to a multiple of ``dim``) are cut into
    consecutive ``dim``-dimensional vectors. The codebook is fitted on a random
    subset of at most ``sample_vectors`` vectors, and the error is measured on
    a second, independently drawn subset of at most ``eval_vectors`` vectors.
    Passing None for either limit uses all vectors.

    Returns ``(quantized_vectors, mse, effective_bits)``. Note that
    ``quantized_vectors`` and ``mse`` cover only the evaluation subset, and
    ``effective_bits = bits / dim`` does not charge for storing the codebook.
    """
    flat = weights.flatten().astype(np.float32)

    leftover = len(flat) % dim
    if leftover:
        filler = np.zeros(dim - leftover, dtype=np.float32)
        flat = np.concatenate([flat, filler])

    vectors = flat.reshape(-1, dim)
    rng = np.random.default_rng(seed)

    def _subsample(limit):
        # Draw without replacement; small inputs are used in full.
        if limit is None or len(vectors) <= limit:
            return vectors
        picks = rng.choice(len(vectors), size=limit, replace=False)
        return vectors[picks]

    # The training draw must come before the evaluation draw so that the
    # random stream is consumed in a fixed order.
    train_vectors = _subsample(sample_vectors)
    fitted = _kmeans_fit(train_vectors, 2 ** bits, seed=seed)

    scored_vectors = _subsample(eval_vectors)
    assignments, centroids = _kmeans_assign(scored_vectors, fitted)
    quantized_vectors = centroids[assignments]

    mse = float(np.mean((scored_vectors - quantized_vectors) ** 2))
    return quantized_vectors, mse, bits / dim


def quantize_product_quantization(weights, bits, dim=4, codebooks=4, sample_vectors=10000, eval_vectors=10000, seed=0):
    """
    Multi-codebook vector quantization.

    The flattened weights (zero-padded to a multiple of ``dim``) are cut into
    consecutive ``dim``-dimensional vectors. Vector ``i`` belongs to codebook
    ``i % codebooks``, and each codebook is an independent k-means fit with
    ``2**bits`` centroids.

    For every codebook, training uses a random subset of at most
    ``sample_vectors`` of its vectors and the error is measured on a second
    random subset of at most ``eval_vectors``; None for either limit uses all
    of that codebook's vectors.

    Returns ``(mse, effective_bits)``: the squared error pooled over all
    evaluated elements, and ``bits / dim``. Unlike ``quantize_kmeans_vq``, no
    quantized array is returned.
    """
    flat = weights.flatten().astype(np.float32)
    n = len(flat)

    # Zero-pad so the length is a multiple of dim.
    pad_size = (dim - n % dim) % dim
    if pad_size > 0:
        flat_padded = np.concatenate([flat, np.zeros(pad_size, dtype=np.float32)])
    else:
        flat_padded = flat

    vectors = flat_padded.reshape(-1, dim)
    # One generator is shared by all codebooks, so the draws below depend on
    # the order in which the codebooks are processed.
    rng = np.random.default_rng(seed)

    # Round-robin assignment of vectors to codebooks.
    group_ids = np.arange(len(vectors)) % codebooks

    # Running sum of squared errors and count of evaluated scalar elements.
    total_se = 0.0
    total_n = 0

    for g in range(codebooks):
        group_vecs = vectors[group_ids == g]
        if len(group_vecs) == 0:
            continue

        if sample_vectors is not None and len(group_vecs) > sample_vectors:
            idx = rng.choice(len(group_vecs), size=sample_vectors, replace=False)
            train_vecs = group_vecs[idx]
        else:
            train_vecs = group_vecs

        n_centroids = 2 ** bits
        # Each codebook gets its own k-means seed.
        model = _kmeans_fit(train_vecs, n_centroids, seed=seed + g)

        if eval_vectors is not None and len(group_vecs) > eval_vectors:
            idx = rng.choice(len(group_vecs), size=eval_vectors, replace=False)
            eval_vecs = group_vecs[idx]
        else:
            eval_vecs = group_vecs

        assignments, centroids = _kmeans_assign(eval_vecs, model)
        quantized = centroids[assignments]

        total_se += float(np.sum((eval_vecs - quantized) ** 2))
        # .size counts scalar elements, so the final MSE is per weight.
        total_n += eval_vecs.size

    mse = total_se / total_n if total_n > 0 else 0.0
    effective_bits = bits / dim
    return mse, effective_bits

In [None]:
# Scalar quantizers: rate-distortion curves

def measure_scalar_quantizers(weights, bits_list=range(1, 9)):
    """
    Run every scalar and group quantizer over a sweep of bit-widths.

    Returns a dict of parallel lists with one entry per bit-width: the
    bit-width itself (``"bits"``), the Shannon distortion at that rate for the
    variance of ``weights`` (``"shannon_d"``), the MSE of each quantizer, and
    for the two group quantizers the effective rate including scale overhead
    (``"group_128_rate"``, ``"group_32_rate"``).
    """
    flat = weights.flatten().astype(np.float32)
    sigma_sq = float(np.var(flat))

    result_keys = (
        "bits",
        "shannon_d",
        "uniform_asym",
        "uniform_sym",
        "clipped_3sigma",
        "group_128",
        "group_32",
        "lloyd_max",
        "group_128_rate",
        "group_32_rate",
    )
    results = {key: [] for key in result_keys}

    for bits in bits_list:
        print(f"  Computing {bits}-bit quantization...")
        d_shannon = shannon_distortion(sigma_sq, bits)
        results["bits"].append(bits)
        results["shannon_d"].append(d_shannon)

        # Quantizers whose rate is exactly `bits` per weight.
        results["uniform_asym"].append(quantize_uniform_asymmetric(weights, bits)[1])
        results["uniform_sym"].append(quantize_uniform_symmetric(weights, bits)[1])
        results["clipped_3sigma"].append(
            quantize_symmetric_clipped(weights, bits, clip_sigma=3.0)[1]
        )

        # Group quantizers also report their own effective rate.
        for group_size in (128, 32):
            _, mse, eff = quantize_group(weights, bits, group_size=group_size)
            results[f"group_{group_size}"].append(mse)
            results[f"group_{group_size}_rate"].append(eff)

        results["lloyd_max"].append(quantize_lloyd_max(weights, bits)[1])

    return results


def plot_rate_distortion_scalar(results, title="Rate-Distortion (Scalar and Group)"):
    """
    Plot distortion (log axis) against rate for the scalar and group quantizers.

    ``results`` is the dict produced by ``measure_scalar_quantizers``. The
    Shannon bound and the fixed-rate quantizers are drawn against
    ``results["bits"]``; the group quantizers are drawn against their own
    effective-rate lists. The figure is shown; nothing is returned.
    """
    plt.figure(figsize=(10, 6))

    bits = np.array(results["bits"], dtype=np.float32)
    plt.semilogy(bits, results["shannon_d"], 'k--', linewidth=2, marker='o', label='Shannon bound')

    # (results key, legend label, line colour)
    methods = [
        ("uniform_asym", "Uniform (naive)", 'red'),
        ("uniform_sym", "Uniform (symmetric)", 'orange'),
        ("clipped_3sigma", "Clipped 3sigma", 'green'),
        ("lloyd_max", "Lloyd-Max", 'cyan'),
    ]

    for key, label, color in methods:
        plt.semilogy(bits, results[key], '-', linewidth=1.5, marker='s', label=label, color=color)

    # Group quantizers use their effective rate (code bits plus scale
    # overhead) on the x-axis, so their points sit right of the integer rates.
    plt.semilogy(results["group_128_rate"], results["group_128"], '-', linewidth=1.5, marker='^', label='Group g=128', color='blue')
    plt.semilogy(results["group_32_rate"], results["group_32"], '-', linewidth=1.5, marker='^', label='Group g=32', color='purple')

    plt.xlabel('Rate (bits per weight)')
    plt.ylabel('Distortion (MSE)')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend(loc='upper right')
    plt.show()

In [None]:
# Vector and product quantization helpers

def measure_vector_quantizers(weights, dims=(2, 4, 8), codebook_bits=6):
    """
    Measure one k-means VQ operating point per vector dimension in ``dims``.

    Each returned entry is a dict with a ``"label"`` plus single-element
    ``"rates"`` and ``"mses"`` lists, ready to be plotted as one point.
    """

    def _operating_point(d):
        print(f"  VQ: dim={d}, codebook_bits={codebook_bits}")
        _, mse, eff_bits = quantize_kmeans_vq(
            weights,
            bits=codebook_bits,
            dim=d,
            sample_vectors=10000,
            eval_vectors=10000,
            seed=0,
        )
        return {
            "label": f"VQ d={d} (codebook {codebook_bits} bits)",
            "rates": [eff_bits],
            "mses": [mse],
        }

    return [_operating_point(d) for d in dims]


def measure_product_quantizers(weights, dims=(2, 4, 8), codebook_bits=6, codebooks=4):
    """
    Measure one multi-codebook operating point per vector dimension in ``dims``.

    Each returned entry is a dict with a ``"label"`` plus single-element
    ``"rates"`` and ``"mses"`` lists, ready to be plotted as one point.
    """

    def _operating_point(d):
        print(f"  PQ: dim={d}, codebook_bits={codebook_bits}, codebooks={codebooks}")
        mse, eff_bits = quantize_product_quantization(
            weights,
            bits=codebook_bits,
            dim=d,
            codebooks=codebooks,
            sample_vectors=10000,
            eval_vectors=10000,
            seed=0,
        )
        return {
            "label": f"PQ d={d} (codebook {codebook_bits} bits, {codebooks} codebooks)",
            "rates": [eff_bits],
            "mses": [mse],
        }

    return [_operating_point(d) for d in dims]

In [None]:
# Rate-distortion analysis for a single layer
# Runs the scalar/group sweep on one MLP down-projection matrix, plots it,
# then measures the VQ and PQ operating points used by the overlay cell.
layer = 8
weights = mlp_weights[layer]

print(f"Analyzing MLP layer {layer}")
print(f"Shape: {weights.shape}, Total weights: {weights.size:,}")
print(f"Variance: {np.var(weights.astype(np.float32)):.6e}")

scalar_results = measure_scalar_quantizers(weights, bits_list=[2, 3, 4, 5, 6, 8])
plot_rate_distortion_scalar(scalar_results, title=f"Rate-Distortion: MLP Layer {layer}")

# With a 6-bit codebook the rate is 6 / dim bits per weight (3, 1.5 and 0.75).
vq_curves = measure_vector_quantizers(weights, dims=(2, 4, 8), codebook_bits=6)
pq_curves = measure_product_quantizers(weights, dims=(2, 4, 8), codebook_bits=6, codebooks=4)

In [None]:
# Overlay vector and product quantization points on scalar curves
# Reuses scalar_results, vq_curves and pq_curves from the previous cell.

plt.figure(figsize=(10, 6))

bits = np.array(scalar_results["bits"], dtype=np.float32)
plt.semilogy(bits, scalar_results["shannon_d"], 'k--', linewidth=2, marker='o', label='Shannon bound')

# The asymmetric ("naive") uniform curve is left out of this figure.
for key, label, color in [
    ("uniform_sym", "Uniform (symmetric)", 'orange'),
    ("clipped_3sigma", "Clipped 3sigma", 'green'),
    ("lloyd_max", "Lloyd-Max", 'cyan'),
]:
    plt.semilogy(bits, scalar_results[key], '-', linewidth=1.5, marker='s', label=label, color=color)

# Group quantizers are drawn against their effective rate.
plt.semilogy(scalar_results["group_128_rate"], scalar_results["group_128"], '-', linewidth=1.5, marker='^', label='Group g=128', color='blue')
plt.semilogy(scalar_results["group_32_rate"], scalar_results["group_32"], '-', linewidth=1.5, marker='^', label='Group g=32', color='purple')

# Each VQ/PQ entry holds a single (rate, mse) point, drawn as a diamond.
for curve in vq_curves + pq_curves:
    plt.semilogy(curve["rates"], curve["mses"], 'D', markersize=7, label=curve["label"])

plt.xlabel('Rate (bits per weight)')
plt.ylabel('Distortion (MSE)')
plt.title(f"Rate-Distortion (MLP layer {layer})")
plt.grid(True, alpha=0.3)
plt.legend(loc='upper right')
plt.show()

In [None]:
# Gap at target distortions
# For each target distortion, report the rate every empirical curve needs to
# reach it and how far that is above the Gaussian Shannon rate.

sigma_sq = float(np.var(weights.flatten().astype(np.float32)))

# Targets are fixed fractions of the layer's weight variance.
targets = [fraction * sigma_sq for fraction in (0.001, 0.01)]

print("Target distortions:")
for t in targets:
    print(f"  D = {t:.3e}")

# (label, key of the rate list, key of the MSE list) in scalar_results
curve_specs = (
    ("Uniform (symmetric)", "bits", "uniform_sym"),
    ("Clipped 3sigma", "bits", "clipped_3sigma"),
    ("Group g=128", "group_128_rate", "group_128"),
    ("Group g=32", "group_32_rate", "group_32"),
    ("Lloyd-Max", "bits", "lloyd_max"),
)
curves = [
    (name, scalar_results[rate_key], scalar_results[mse_key])
    for name, rate_key, mse_key in curve_specs
]

for t in targets:
    print(f"\n=== Gap summary at D={t:.3e} ===")
    for label, rates, mses in curves:
        r_emp = interpolate_rate_for_distortion(rates, mses, t)
        if r_emp is not None:
            r_shannon = 0.5 * math.log2(sigma_sq / t)
            gap = r_emp - r_shannon
            print(f"  {label:<16} : R_emp={r_emp:.2f} bits, gap={gap:.2f} bits")
        else:
            # The measured points do not bracket this distortion.
            print(f"  {label:<16} : insufficient coverage")

In [None]:
# Per-group comparison (MLP vs Attention vs Embeddings)

# Bit-widths 1..8 for this comparison; this rebinds the notebook-level
# bits_list that the quick bit-sweep cell set earlier.
bits_list = list(range(1, 9))


def select_layers(layer_dict, max_layers=6):
    """
    Return the sorted keys of ``layer_dict``, thinned to at most ``max_layers``.

    When there are more keys than ``max_layers``, the picks are spread evenly
    over the sorted keys and always include the first and the last one.
    """
    ordered = sorted(layer_dict.keys())
    total = len(ordered)
    if total > max_layers:
        positions = np.linspace(0, total - 1, max_layers).round().astype(int)
        ordered = [ordered[pos] for pos in positions]
    return ordered


def avg_rd_curve(weights_list, bits_list, quantizer_fn):
    """
    Average a quantizer's MSE over several weight tensors at each bit-width.

    ``quantizer_fn(W, bits)`` must return a ``(quantized, mse)`` pair. Returns
    one float per entry of ``bits_list``: the mean MSE over ``weights_list``.
    """

    def _mean_mse(bits):
        per_tensor = []
        for W in weights_list:
            _, mse = quantizer_fn(W, bits)
            per_tensor.append(mse)
        return float(np.mean(per_tensor))

    return [_mean_mse(bits) for bits in bits_list]

# Compare at most six evenly spaced layers from each per-layer group.
mlp_layers = select_layers(mlp_weights, max_layers=6)
q_layers = select_layers(attn_weights["q"], max_layers=6)

mlp_list = [mlp_weights[i] for i in mlp_layers]
q_list = [attn_weights["q"][i] for i in q_layers]

# Average MSE per bit-width under the symmetric uniform quantizer.
mlp_curve = avg_rd_curve(mlp_list, bits_list, quantize_uniform_symmetric)
q_curve = avg_rd_curve(q_list, bits_list, quantize_uniform_symmetric)

# The token embedding is a single matrix, so no averaging is needed; the
# curve is skipped when the embedding was not extracted.
emb_curve = None
if "token" in embedding_weights:
    emb_curve = [quantize_uniform_symmetric(embedding_weights["token"], b)[1] for b in bits_list]

plt.figure(figsize=(10, 6))
plt.semilogy(bits_list, mlp_curve, '-o', label=f"MLP down (avg {len(mlp_layers)} layers)")
plt.semilogy(bits_list, q_curve, '-o', label=f"Attention Q (avg {len(q_layers)} layers)")
if emb_curve is not None:
    plt.semilogy(bits_list, emb_curve, '-o', label="Token embeddings")

plt.xlabel('Rate (bits per weight)')
plt.ylabel('Distortion (MSE)')
plt.title('Per-group comparison (uniform symmetric)')
plt.grid(True, alpha=0.3)
plt.legend(loc='upper right')
plt.show()