In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """Vector-quantization layer with a learnable codebook.

    Every D-dimensional latent vector of the input is replaced by its nearest
    codebook entry (squared Euclidean distance). The returned tensor uses a
    straight-through estimator so gradients flow back to the input unchanged,
    while the returned loss trains the codebook and penalizes the input for
    drifting away from its selected code.
    """

    def __init__(self, num_codes=256, code_dim=64, beta=0.25, remap=None):
        """
        Args:
            num_codes: number of entries in the codebook (K).
            code_dim: dimensionality of each code vector (D).
            beta: weight of the commitment loss term.
            remap: optional index-remapping setting (default None). It is only
                stored on the module; ``forward`` does not apply it.
        """
        super().__init__()
        self.num_codes = num_codes
        self.code_dim = code_dim
        self.beta = beta
        self.remap = remap

        # Codebook: K x D, initialized uniformly in [-1/K, 1/K].
        self.embedding = nn.Embedding(num_codes, code_dim)
        nn.init.uniform_(self.embedding.weight, -1.0 / num_codes, 1.0 / num_codes)

    def forward(self, z_e):
        """Quantize a channels-first latent tensor.

        Args:
            z_e: tensor of shape (B, D, *spatial) with D == ``code_dim``,
                e.g. (B, D, H, W) for image latents.

        Returns:
            A tuple ``(z_q_st, loss, indices)``:
              * z_q_st: quantized latent with the same shape as ``z_e``;
                its gradient w.r.t. ``z_e`` is the identity (straight-through).
              * loss: scalar ``beta * commitment_loss + codebook_loss``.
              * indices: flat tensor of shape (N,) with the selected code index
                for each latent vector, where N = B * prod(spatial), ordered
                batch-major over the channels-last layout.

        Raises:
            ValueError: if ``z_e`` has fewer than 2 dimensions or its channel
                dimension (dim 1) does not equal ``code_dim``.
        """
        if z_e.dim() < 2 or z_e.shape[1] != self.code_dim:
            raise ValueError(
                f"expected input of shape (B, {self.code_dim}, ...), "
                f"got {tuple(z_e.shape)}"
            )

        # Channels-last view (B, *spatial, D): each position holds one D-vector.
        z_last = z_e.movedim(1, -1)
        z_flat = z_last.reshape(-1, self.code_dim)  # (N, D)
        codebook = self.embedding.weight  # (K, D)

        # The nearest-neighbour search is not differentiable (argmin), so there
        # is no point recording an autograd graph for the (N, K) distance matrix.
        with torch.no_grad():
            # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, broadcast to (N, K)
            dist = (
                torch.sum(z_flat ** 2, dim=1, keepdim=True)
                + torch.sum(codebook ** 2, dim=1)
                - 2 * torch.matmul(z_flat, codebook.t())
            )
            indices = torch.argmin(dist, dim=1)  # (N,)

        # Look up the selected codes (differentiable w.r.t. the codebook).
        z_q = self.embedding(indices).view(z_last.shape)  # (B, *spatial, D)

        # Commitment term pulls the input towards its (frozen) code;
        # codebook term pulls the code towards the (frozen) input.
        loss_commit = self.beta * F.mse_loss(z_last, z_q.detach())
        loss_codebook = F.mse_loss(z_q, z_last.detach())
        loss = loss_commit + loss_codebook

        # Straight-through estimator: forward value is z_q, gradient is that of z_e.
        z_q_st = z_last + (z_q - z_last).detach()

        # Back to channels-first (B, D, *spatial).
        z_q_st = z_q_st.movedim(-1, 1).contiguous()

        return z_q_st, loss, indices





In [None]:
# Smoke test: quantize a random latent and inspect the results.
torch.manual_seed(0)  # fixed seed so the printed loss is reproducible
encoder_output = torch.randn(2, 64, 32, 16)  # (B=2, D=64, H=32, W=16)

vq = VectorQuantizer(num_codes=256, code_dim=64)

z_q, loss, indices = vq(encoder_output)

# Sanity checks: quantization preserves the latent shape, and there is one
# code index per spatial position (flattened over B, H, W).
assert z_q.shape == encoder_output.shape
assert indices.shape == (encoder_output.numel() // encoder_output.shape[1],)

print("Quantized latent shape:", z_q.shape)   # (2, 64, 32, 16)
print("Indices shape:", indices.shape)        # (1024,) = B*H*W = 2*32*16
print("Loss:", loss.item())



Quantized latent shape: torch.Size([2, 64, 32, 16])
Indices shape: torch.Size([1024])
Loss: 1.2449116706848145
