In [92]:
import numpy as np
import torch.nn as nn
import torch
# ==========================================================
# 基础量化器：Two-stage & Mean-value
# ==========================================================

def uniform_quantizer(x,qmin,qmax,Q):
    x=np.clip(x,qmin,qmax)
    indices=np.around((x-qmin)/(qmax-qmin)*(Q-1))
    return(indices,qmin,qmax)

def two_stage_quantizer(a,Q_ep=16,Q_entry=64):
    a_min= np.min(a,axis=0)
    a_max= np.max(a,axis=0)
    min_indices,min_low,min_up=uniform_quantizer(a_min,np.min(a_min),np.max(a_min),Q_ep)
    max_indices,max_low,max_up=uniform_quantizer(a_max,np.min(a_max),np.max(a_max),Q_ep)
    min_q=min_indices/(Q_ep-1)*(min_up-min_low)+min_low
    max_q=max_indices/(Q_ep-1)*(max_up-max_low)+max_low
    entry_indices,entry_low,entry_up=uniform_quantizer(a,min_q,max_q,Q_entry)
    Q=entry_indices/Q_entry*(entry_up-entry_low)+entry_low
    return(Q)


def mean_value_quantizer(a, Q0=8):
    """均值量化器 (Mean-value quantizer)"""
    a_mean = np.mean(a,axis=0)
    mean_indices,mean_low,mean_up=uniform_quantizer(a_mean,np.min(a_mean),np.max(a_mean),Q0)
    mean_q=mean_indices/Q0*(mean_up-mean_low)+mean_low
    a_q = np.repeat(mean_q.reshape(1, -1),a.shape[0], axis=0)
    return a_q,mean_q


# ==========================================================
# (P) 问题求解：Water-filling 分配 Q_i
# ==========================================================
def solve_quantization_levels(a_tilde,a_tilde_0, B, Cava, tol=1e-4, max_iter=100):
    """
    求解 (P): 最优 Q_i 分配（基于 KKT 条件的 water-filling）
    a_tilde: 向量 range 数组
    B: batch size
    Cava: 可用通信预算 (bit)
    """
    M = len(a_tilde)
    nu_low, nu_high = 1e-12, 1e6

    def compute_Q(nu):
        u = (a_tilde**2 * np.log(2)) / (2 * nu)
        u0=(a_tilde_0**2*B* np.log(2)) / (nu)
        u=np.insert(u,0,u0)
        v = (u * np.sqrt(81 - 12*u) + 9*u) ** (1/3)
        Q = ((2/3)**(1/3)) * (u / v) + v / (2**(1/3) * 3**(2/3)) + 1
        Q = np.clip(Q, 2, 2**32)

        return Q
    for _ in range(max_iter):
        nu_mid = (nu_low + nu_high) / 2
        Q_mid = compute_Q(nu_mid)
        bit_sum = np.sum(np.log2(Q_mid))  # 总通信开销近似
        if bit_sum > Cava:  # 超出预算 -> 增大 ν
            nu_low = nu_mid
        else:
            nu_high = nu_mid
        if abs(bit_sum - Cava) < tol:
            break

    return compute_Q(nu_mid), nu_mid


# ==========================================================
# 自动确定 M* ：搜索最优 M 来最小化量化误差
# ==========================================================
def auto_determine_M_and_Q(a_ranges,a_tilde_0, B, D_hat, Cava):
    """
    自动搜索最优 M* 并为其求解 (P)
    """
    candidates = np.unique(np.linspace(1, D_hat // 2, num=8, dtype=int))
    best_M, best_Q, best_err = 0, None, np.inf
    for M in candidates:
        a_tilde = a_ranges[:M]
        a_bar=a_ranges[M:]
        Q_all, _ = solve_quantization_levels(a_tilde,a_tilde_0, B, Cava)
        Q0=Q_all[0]
        Q_entry=Q_all[1:]
        # 计算误差上界 (式19)
        err_two = np.sum((a_tilde**2 * B) / (4 * (Q_entry - 1)**2))
        err_mean_1 = np.sum((a_bar**2 * B)/2) # mean-quantizer误差近似
        err_mean_2=a_tilde_0**2*B*(D_hat-M)/(2*(Q0 - 1)**2)
        total_err = err_two + err_mean_1+err_mean_2
        if total_err < best_err:
            best_err = total_err
            best_M = M
            best_Q = Q_all

    return best_M, best_Q


# ==========================================================
# 主算法：Adaptive Feature-Wise Quantization
# ==========================================================
def adaptive_featurewise_quantization(A, Cava, Q_ep=200):
    """
    自适应特征量化算法（完整版本）
    输入：
        A: (B×D̂) 中间特征矩阵
        Cava: 总通信预算（bit）
    输出：
        Q: 量化后矩阵
        mu: 均值量化向量
        M*: 使用 two-stage quantizer 的列数
        Q_entry_list: 对应的量化级数组
    """
    B, D_hat = A.shape
    A=A
    ranges = np.max(A, axis=0) - np.min(A, axis=0)
    idx_sorted = np.argsort(-ranges)  # 按range从大到小排序
    A_sorted = A[:, idx_sorted]
    ranges_sorted = ranges[idx_sorted]
    a_tilde_0=np.max(np.mean(A,axis=0))-np.min(np.mean(A,axis=0))
    # ① 自动确定 M* 和每列最优 Q_i (P问题)
    M_star, Q_entry_list = auto_determine_M_and_Q(ranges_sorted,a_tilde_0, B, D_hat, Cava)
    # ② 执行量化
    Q = np.zeros_like(A_sorted)
    mu = np.zeros(D_hat)
    Q[:,:M_star]=two_stage_quantizer(A_sorted[:, 0:M_star], Q_ep=Q_ep,Q_entry=np.around(Q_entry_list[1:]))
    Q[:,M_star:],mu= mean_value_quantizer(A_sorted[:,M_star:], np.around(Q_entry_list[0]))

    inv_idx = np.argsort(idx_sorted)
    Q = Q[:, inv_idx]
    return Q, M_star, Q_entry_list

class FWQ(nn.Module):
    def __init__(self,token_dim,code_dim,discrete_size):
        super(FWQ,self).__init__()
        self.discrete_size=discrete_size
    def forward(self,x,return_indice=False):
        x_shape=x.shape
        flattened_x=x.view(-1,x.shape[2]) #[B*S,H]
        flattened_x_quantized,M_star, Q_entry_list=adaptive_featurewise_quantization(flattened_x,int(np.log2(self.discrete_size)))       
        output=(torch.tensor(flattened_x_quantized).to(x.device)-flattened_x).detach()+flattened_x
        L_comm=0
        L_code=0
        return output.reshape(x_shape),L_code,L_comm

In [95]:
X=np.random.randn(4,50)
X_q,_,_=adaptive_featurewise_quantization(X,16)

  v = (u * np.sqrt(81 - 12*u) + 9*u) ** (1/3)


In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def generate_nf_table(bits: int, device="cpu", dtype=torch.float16):
    n_levels = 2 ** bits
    # 在 (-1, 1) 的概率区间上均匀采样
    probs = torch.linspace(0, 1, n_levels + 1)[1:-1]
    # 根据标准正态分布的分位数函数生成
    values = torch.erfinv(2 * probs - 1)
    values = values / values.abs().max()  # 归一化到 [-1, 1]
    values = torch.cat([torch.tensor([-1.0]), values, torch.tensor([1.0])])
    return values.to(device=device, dtype=dtype)

class NFNDoubleQuantizer(nn.Module):
    def __init__(self, bits=4, block_size=64, use_double_quant=True):
        super().__init__()
        self.bits = bits
        self.block_size = block_size
        self.use_double_quant = use_double_quant
        self.table=generate_nf_table(bits)

    def quantize(self, x):
        orig_shape = x.shape
        x = x.view(orig_shape[0],orig_shape[1]//self.block_size, self.block_size)

        q_idx_list, scale_list, min_list = [], [], []
        x_min, x_max = x.min(dim=2).values.unsqueeze(-1), x.max(dim=2).values.unsqueeze(-1)
        scales = (x_max - x_min).squeeze(-1)
        x_norm = 2 * (x - x_min) / (x_max - x_min + 1e-8) - 1
        dist = torch.abs(x_norm.unsqueeze(-1) - self.table.to(x.device))
        q_idx = torch.argmin(dist, dim=-1).to(torch.uint8)
        mins = x_min
        # Double Quantization: 再量化scale (8bit)
        if self.use_double_quant:
            s_min, s_max = scales.min(dim=-1).values.unsqueeze(-1), scales.max(dim=-1).values.unsqueeze(-1)
            scales_q = ((scales - s_min) / (s_max - s_min + 1e-8) * 255).round().to(torch.uint8)
        else:
            scales_q, s_min, s_max = None, None, None

        return q_idx, scales_q, s_min, s_max, mins

    def dequantize(self, q_idx, scales_q, s_min, s_max, mins):
        if scales_q is not None:
            scales = s_min + (scales_q.float() / 255) * (s_max - s_min)
        else:
            scales = s_min
        scales=scales.unsqueeze(-1)
        w_block = self.table[q_idx.to(device=self.table.device,dtype=torch.long)].to(torch.float32)
        w_block = (w_block + 1) / 2 * scales +mins
 
        return w_block.reshape(-1,q_idx.shape[1]*self.block_size)
    
class Qlora_quantize(nn.Module):
    def __init__(self, bits=4, block_size=64, use_double_quant=True):
        super().__init__()
        self.quantizer = NFNDoubleQuantizer(bits, block_size, use_double_quant)

    def forward(self, x):
        flattened_x=x.view(-1,x.shape[2])
        q_idx, s_q, s_min, s_max, mins = self.quantizer.quantize(flattened_x)
        flattened_x_q = self.quantizer.dequantize(q_idx, s_q, s_min, s_max, mins)
        output=flattened_x + (flattened_x_q - flattened_x).detach()
        return output.reshape(x.shape)

In [58]:
X=torch.rand(8,40,1280)
quant=Qlora_quantize()
q=quant(X)

torch.Size([320, 1])


tensor([[[0.8811, 0.1176, 0.6087,  ..., 0.9675, 0.6174, 0.4613],
         [0.3412, 0.0083, 0.1451,  ..., 0.8400, 0.9867, 0.7271],
         [0.2848, 0.7816, 0.8029,  ..., 0.0015, 0.7203, 0.9308],
         ...,
         [0.2455, 0.6643, 0.5318,  ..., 0.9995, 0.7583, 0.2779],
         [0.6788, 0.6938, 0.4952,  ..., 0.5393, 0.1716, 0.7679],
         [0.9702, 0.9675, 0.5222,  ..., 0.7753, 0.7366, 0.8458]],

        [[0.8897, 0.0724, 0.7113,  ..., 0.2052, 0.6756, 0.8374],
         [0.4419, 0.9382, 0.1345,  ..., 0.4369, 0.0828, 0.7452],
         [0.7984, 0.5462, 0.6909,  ..., 0.4985, 0.4916, 0.0909],
         ...,
         [0.5281, 0.8096, 0.6119,  ..., 0.2170, 0.2455, 0.5626],
         [0.5345, 0.1439, 0.9436,  ..., 0.8310, 0.2409, 0.2408],
         [0.6460, 0.4772, 0.4199,  ..., 0.2782, 0.3231, 0.4355]],

        [[0.8019, 0.5612, 0.5104,  ..., 0.7873, 0.2197, 0.2938],
         [0.3004, 0.7838, 0.6811,  ..., 0.7415, 0.7341, 0.5568],
         [0.3735, 0.6604, 0.7661,  ..., 0.9719, 0.7927, 0.