In [18]:
import numpy as np
import matplotlib.pyplot as plt

def simplified_bound(z, L):
    """Bound, round and renormalise ``z`` for a quantizer with ``L`` levels.

    The input is squashed with ``tanh`` and scaled to ``half_l``; for even
    ``L`` the tanh input is shifted and the output is lowered by 0.5 before
    rounding. The rounded value is divided by ``L // 2``.

    Args:
        z: scalar or NumPy array of inputs.
        L: number of quantization levels.

    Returns:
        The rounded, bounded values divided by ``L // 2``.
    """
    eps = 1e-3
    half_l = (L - 1) * (1 - eps) / 2

    # Only even level counts need the half-step offset and the matching input shift.
    is_even = L % 2 == 0
    offset = 0.5 if is_even else 0.0
    shift = np.tan(offset / half_l)

    bounded = np.tanh(z + shift) * half_l - offset
    return round_ste(bounded) / (L // 2)

def simplified_bound_no_shift(z, L):
    """Variant of the bound function that omits the input shift.

    ``tanh(z)`` is scaled to ``half_l``, lowered by 0.5 when ``L`` is even,
    rounded, and divided by ``L // 2``.

    Args:
        z: scalar or NumPy array of inputs.
        L: number of quantization levels.

    Returns:
        The rounded, bounded values divided by ``L // 2``.
    """
    eps = 1e-3
    half_l = (L - 1) * (1 - eps) / 2

    # Even level counts get the half-step offset; odd ones are left untouched.
    if L % 2 == 0:
        offset = 0.5
    else:
        offset = 0.0

    bounded = np.tanh(z) * half_l - offset
    return round_ste(bounded) / (L // 2)

def simplified_bound_no_shift_no_offset(z, L):
    """Scale ``tanh(z)`` to the half-range used for ``L`` levels.

    No offset, shift, rounding or renormalisation is applied.

    Args:
        z: scalar or NumPy array of inputs.
        L: number of quantization levels.

    Returns:
        ``tanh(z) * half_l`` with ``half_l = (L - 1) * (1 - eps) / 2`` and
        ``eps = 1e-3``.
    """
    eps = 1e-3
    half_l = (L - 1) * (1 - eps) / 2
    z_hat = np.tanh(z) * half_l
    # The value was previously computed but never returned, so callers got None.
    return z_hat

def simplified_bound_no_shift_no_offset_no_eps(z, L):
    """Scale ``tanh(z)`` by ``(L - 1) / 2`` with no epsilon shrinkage.

    No offset, shift, rounding or renormalisation is applied.

    Args:
        z: scalar or NumPy array of inputs.
        L: number of quantization levels.

    Returns:
        ``tanh(z) * (L - 1) / 2``.
    """
    half_l = (L - 1) / 2
    z_hat = np.tanh(z) * half_l
    # The value was previously computed but never returned, so callers got None.
    return z_hat

def simplified_bound_no_shift_no_offset_no_scale(z, L):
    """Return the plain ``tanh`` of ``z``.

    Args:
        z: scalar or NumPy array of inputs.
        L: number of quantization levels; not used by this variant and kept
            only so the signature matches the other bound functions.

    Returns:
        ``tanh(z)``.
    """
    # The value was previously computed but never returned, so callers got None.
    return np.tanh(z)

def round_ste(z):
    """Round ``z`` element-wise with ``np.round``.

    This NumPy version only rounds; it carries no gradient logic.
    """
    return np.round(z)

In [None]:
# Input grid over which every bound-function variant is evaluated.
z = np.linspace(-3, 3, 1000)

# Evaluate each variant for an even number of levels (L=4).
z_bound_even = simplified_bound(z, 4)
z_bound_even_no_shift = simplified_bound_no_shift(z, 4)
z_bound_even_no_shift_no_offset = simplified_bound_no_shift_no_offset(z, 4)
z_bound_even_no_shift_no_offset_no_scale = simplified_bound_no_shift_no_offset_no_scale(z, 4)

# NOTE(review): simplified_bound and simplified_bound_no_shift already round
# their output, so the "*_rounded" names alias the same arrays and the
# Step 3 / Step 5 (No Shift) and Step 4 / Step 5 curves coincide.
z_bound_even_rounded = z_bound_even
z_bound_even_no_shift_rounded = z_bound_even_no_shift

plt.figure(figsize=(10, 6))

# Curves are drawn in list order, which also fixes the legend order.
even_l_curves = [
    # Steps 1 and 2 are currently disabled:
    # (z_bound_even_no_shift_no_offset_no_scale, dict(label='Step 1: Bound Function, L=4 (tanh w/ eps)', color='purple', linestyle='--')),
    # (z_bound_even_no_shift_no_offset, dict(label='Step 2: Bound Function, L=4 (tanh + scale)', color='red', linestyle='--')),
    (z_bound_even_no_shift, dict(label='Step 3: Bound Function, L=4 (tanh + scale + offset)', color='gold', linestyle='--')),
    (z_bound_even, dict(label='Step 4: Bound Function, L=4 (tanh + scale + offset + shift)', color='green')),
    (z_bound_even_rounded, dict(label='Step 5: Round', color='black')),
    (z_bound_even_no_shift_rounded, dict(label='Step 5: Round (No Shift)', color='grey', linestyle='--')),
]
for curve, line_style in even_l_curves:
    plt.plot(z, curve, **line_style)

# Title, axis labels and legend
plt.title('Effect of Bound Function for Even L')
plt.xlabel('Input z')
plt.ylabel('Transformed z')
plt.legend()

# Grid for better readability
plt.grid(True)

# Render the combined figure
plt.show()

In [None]:
# Evaluate the full bound function for an odd number of levels (L=5),
# reusing the input grid `z` defined earlier in the notebook.
z_bound_odd = simplified_bound(z, 5)

plt.figure(figsize=(10, 6))
plt.plot(z, z_bound_odd, label='Bound Function, L=5 (Odd)')

# Title, axis labels, legend and grid
plt.title('Effect of Bound Function for Odd L')
plt.xlabel('Input z')
plt.ylabel('Transformed z')
plt.legend()
plt.grid(True)

plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FSQ(nn.Module):
    """Finite scalar quantizer over the last dimension of the input.

    Each entry of the last dimension is bounded with ``tanh``, rounded to one
    of ``levels[i]`` integer bins, and renormalised by ``levels[i] // 2``.
    Quantized codes can be converted to and from flat integer indices.

    Args:
        levels: sequence of ints, the number of quantization levels per
            dimension. Must be non-empty and every entry must be >= 2.
        eps: shrink factor applied to the tanh scale in ``quantize``.

    Raises:
        ValueError: if ``levels`` is empty or contains a value below 2.
    """

    def __init__(self, levels, eps=1e-3):
        super().__init__()

        # Normalise to a list so tuples (or other sequences) work with the
        # list concatenation used to build the basis below.
        levels = [int(n) for n in levels]
        if not levels or min(levels) < 2:
            # A level of 1 makes both half_l and half_width zero, which
            # turns every quantized value into NaN.
            raise ValueError(
                f'levels must be a non-empty sequence of ints >= 2, got {levels}'
            )

        self.register_buffer('levels', torch.tensor(levels))
        # Mixed-radix place values: basis[i] = prod(levels[:i]).
        self.register_buffer(
            'basis',
            torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
        )

        self.eps = eps
        self.codebook_size = torch.prod(self.levels)

    def round_ste(self, z):
        """Round ``z`` while passing gradients straight through to ``z``."""
        z_q = torch.round(z)
        # Forward value is z_q; the detached difference contributes no gradient.
        return z + (z_q - z).detach()

    def quantize(self, z):
        """Quantize ``z`` and return codes renormalised by ``levels // 2``.

        Args:
            z: tensor whose last dimension broadcasts against ``levels``.

        Returns:
            Tensor of quantized values divided by ``levels // 2``.
        """
        # half_l is used to determine how to scale tanh; we
        # subtract 1 from the number of levels to account for 0
        # being a quantization bin and tanh being symmetric around 0
        half_l = (self.levels - 1) * (1 - self.eps) / 2

        # if a given level is even, it will result in a scale for tanh
        # which is halfway between integer values, so we offset
        # the tanh output down by 0.5 to line it with whole integers
        offset = torch.where(self.levels % 2 == 0, 0.5, 0.0)

        # if our level is even, we want to shift the tanh input to
        # ensure the 0 quantization bin is centered
        # NOTE(review): with tan, tanh(shift) only approximates
        # offset / half_l; the exact inverse of tanh would be atanh.
        # Confirm that tan is the intended function here.
        shift = torch.tan(offset / half_l)

        # once we have our shift and offset (in the case of an even level)
        # we can round to the nearest integer bin and allow for STE
        z_q = self.round_ste(torch.tanh(z + shift) * half_l - offset)

        # after quantization, we want to renormalize the quantized
        # values to be within the range expected by the model (ie. [-1, 1])
        half_width = self.levels // 2
        return z_q / half_width

    def scale_and_shift(self, z_q_normalized):
        """Map renormalised codes to non-negative per-dimension bin values."""
        half_width = self.levels // 2
        return (z_q_normalized * half_width) + half_width

    def scale_and_shift_inverse(self, z_q):
        """Map non-negative per-dimension bin values back to renormalised codes."""
        half_width = self.levels // 2
        return (z_q - half_width) / half_width

    def code_to_idxs(self, z_q):
        """Convert renormalised codes to flat ``int32`` codebook indices.

        Args:
            z_q: renormalised codes, as returned by ``quantize``.

        Returns:
            ``int32`` tensor with the last dimension reduced away.
        """
        # Round each digit before the integer cast so float round-off just
        # below an integer cannot be truncated to the wrong bin, then do the
        # mixed-radix sum in integer arithmetic.
        digits = torch.round(self.scale_and_shift(z_q)).to(self.basis.dtype)
        return (digits * self.basis).sum(dim=-1).to(torch.int32)

    def idxs_to_code(self, idxs):
        """Convert flat codebook indices back to renormalised codes.

        Args:
            idxs: integer tensor of flat indices.

        Returns:
            Tensor with a new trailing dimension of size ``len(levels)``.
        """
        idxs = idxs.unsqueeze(-1)
        # Extract the mixed-radix digit for every dimension.
        codes_not_centered = (idxs // self.basis) % self.levels
        return self.scale_and_shift_inverse(codes_not_centered)

    def forward(self, z):
        """Quantize ``z``; equivalent to calling ``quantize(z)``."""
        return self.quantize(z)

In [None]:
# Quantize one 3-dimensional vector with 3, 5 and 4 levels per dimension.
fsq = FSQ([3, 5, 4])
z = torch.tensor([0.25, 0.6, 6])
z_q = fsq.quantize(z)
print(f'{z} -> {z_q}')

# Collapse the quantized code into a single flat codebook index.
idx = fsq.code_to_idxs(z_q)
print(f'code {z_q} is the {idx}-th index')

# Expand the index again to check that it maps back to the same code.
code = fsq.idxs_to_code(idx)
print(f'idx {idx} mapped back to {code}')