In [1]:
import torch
import torch.nn.functional as F

m,n = 4096 // 4, 4096 // 4
w = torch.randn(m, n).cuda()

vecdim = 8
index_bits = 12 #16 #8
n_layers = 256

indices_shape = m * n // vecdim

with torch.no_grad():
	indices = torch.randint(0, 2**index_bits, (n_layers, indices_shape)).cuda()

# codebook = torch.nn.Parameter(torch.randn((n_layers, 2**index_bits, vecdim), dtype=torch.float32).cuda(), requires_grad=True)
# scales = torch.nn.Parameter(torch.ones((n_layers, m, 1), dtype=torch.float32).cuda() / n_layers, requires_grad=True)
# optim = torch.optim.Adam(params=[codebook, scales], lr=1e-2)


codebook = torch.nn.Parameter(torch.randn((2**index_bits, vecdim), dtype=torch.float32).cuda(), requires_grad=True)
optim = torch.optim.Adam(params=[codebook,], lr=1e-2)


loss_fn = torch.nn.MSELoss()

In [2]:
def vector_quantize(w, codebook):
    """
    Векторное квантование матрицы
    
    Args:
        w: torch.Tensor shape (m, n) - исходная матрица
        codebook: torch.Tensor shape (codebook_size, vecdim) - кодбук
    
    Returns:
        w_q: torch.Tensor shape (m, n) - квантованная матрица
        indices: torch.Tensor - выбранные индексы
    """
    m, n = w.shape
    codebook_size, vecdim = codebook.shape
    
    # Проверяем совместимость размеров
    if (m * n) % vecdim != 0:
        raise ValueError(f"Размер матрицы {m}x{n} не делится на vecdim={vecdim}")
    
    # Разбиваем матрицу на векторы
    n_vectors = m * n // vecdim
    w_vectors = w.reshape(n_vectors, vecdim)  # (n_vectors, vecdim)
    
    # Вычисляем попарные расстояния между всеми векторами и кодами
    # w_vectors: (n_vectors, vecdim)
    # codebook: (codebook_size, vecdim)
    distances = torch.cdist(w_vectors, codebook, p=2)  # (n_vectors, codebook_size)
    
    # Находим ближайшие индексы
    indices = torch.argmin(distances, dim=1)  # (n_vectors,)
    
    # Заменяем векторы на ближайшие из кодбука
    w_q_vectors = codebook[indices]  # (n_vectors, vecdim)
    
    # Восстанавливаем исходную форму
    w_q = w_q_vectors.reshape(m, n)
    
    return w_q, indices

w_q, best_indices = vector_quantize(w, codebook)

loss_fn(w_q, w)

tensor(0.1845, device='cuda:0', grad_fn=<MseLossBackward0>)

In [3]:
def quantize_with_constraints_vectorized(w, codebook, allowed_indices, vecdim):
    """
    Векторизованная версия для лучшей производительности
    """
    m, n = w.shape
    n_vectors = m * n // vecdim
    
    # Разбиваем на векторы
    w_vectors = w.reshape(-1, vecdim)  # (n_vectors, vecdim)
    
    # Получаем все разрешенные векторы для всех позиций
    # allowed_indices: (n_layers, n_vectors)
    # -> allowed_vectors: (n_layers, n_vectors, vecdim)
    allowed_vectors = codebook[allowed_indices]
    
    # Вычисляем расстояния
    # w_vectors: (n_vectors, vecdim) -> (1, n_vectors, vecdim)
    # allowed_vectors: (n_layers, n_vectors, vecdim)
    distances = F.mse_loss(
        w_vectors.unsqueeze(0).expand_as(allowed_vectors),
        allowed_vectors,
        reduction='none'
    ).mean(dim=2)  # (n_layers, n_vectors)
    
    # Находим лучшие индексы
    best_layer_indices = torch.argmin(distances, dim=0)  # (n_vectors,)
    best_indices = allowed_indices[best_layer_indices, torch.arange(n_vectors)]
    
    # Собираем результат
    best_vectors = codebook[best_indices]  # (n_vectors, vecdim)
    w_q = best_vectors.reshape(m, n)
    
    return w_q, best_indices

w_q, best_indices = quantize_with_constraints_vectorized(w, codebook, indices, vecdim)

loss_fn(w_q, w)

tensor(0.3707, device='cuda:0', grad_fn=<MseLossBackward0>)

In [4]:
n_steps = 2500

loss_fn = torch.nn.MSELoss()
for steps in range(n_steps):
	#w_q, best_indices = vector_quantize(w, codebook)
	w_q, best_indices = quantize_with_constraints_vectorized(w, codebook, indices, vecdim)
	
	# w_q = torch.gather(
	# 	codebook,
	# 	1,
	# 	indices.unsqueeze(-1).expand(*indices.shape, vecdim)
	# ).reshape(n_layers, m, n)
	# w_q = (scales * w_q).sum(0)
	
	loss = loss_fn(w, w_q)
	loss.backward()
	optim.step()
	optim.zero_grad()
	
	if steps % (n_steps // 10) == 0:
		print(loss)

tensor(0.3707, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(0.2946, device='cuda:0', grad_fn=<MseLossBackward0>)


KeyboardInterrupt: 