```
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# PyTorch FSQ

In [None]:
!pip install einops

In [None]:
from typing import List, Optional
import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor, int32
from einops import rearrange

def round_ste(z: Tensor) -> Tensor:
    """Round `z` to the nearest integer with a straight-through gradient.

    The forward value equals `z.round()`; the rounding residual is detached,
    so the backward pass treats the operation as the identity.
    """
    residual = z.round() - z
    return z + residual.detach()

# main class
class FSQ(Module):
    """Finite scalar quantizer (PyTorch).

    Every one of the ``d = len(levels)`` channels on the last axis is bounded,
    rounded to one of ``levels[i]`` integer values and divided by
    ``levels[i] // 2``. The set of all such codes forms an implicit codebook
    with ``prod(levels)`` entries; a code is addressed by a mixed-radix
    integer index in which channel 0 is the least significant digit.

    Args:
        levels: Number of quantization levels per channel (each >= 2).
        eps: Relative shrink applied to the half-range used by `bound`.

    Raises:
        ValueError: If `levels` is empty or contains a value below 2.
    """

    def __init__(
        self,
        levels: List[int],
        eps: float = 1e-3,
    ):
        super().__init__()

        # Accept any sequence; the list concatenation below needs a list.
        levels = list(levels)
        # A level of 1 gives a zero half-width and therefore NaN outputs.
        if not levels or any(n < 2 for n in levels):
            raise ValueError(
                f"levels must be a non-empty sequence of ints >= 2, got {levels}"
            )

        self._eps = eps

        _levels = torch.tensor(levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent=False)

        # Mixed-radix place values: basis[i] = prod(levels[:i]).
        _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent=False)

        self.codebook_dim = len(levels)
        self.codebook_size = self._levels.prod().item()

        # Materialise every code once; shape (codebook_size, codebook_dim).
        implicit_codebook = self.indices_to_codes(
            torch.arange(self.codebook_size), project_out=False
        )
        self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)

    def bound(self, z: Tensor) -> Tensor:
        """Bound `z`, an array of shape (..., d).

        Channels with an odd number of levels are squashed symmetrically
        around zero; channels with an even number of levels are shifted by
        half a step so that rounding yields exactly `levels[i]` values.
        """
        half_l = (self._levels - 1) * (1 - self._eps) / 2
        offset = torch.where(self._levels % 2 == 1.0, 0.0, 0.5)
        # NOTE(review): tanh(shift) * half_l only approximately equals offset
        # when tan() is used; atanh() would be exact but is undefined once
        # offset / half_l >= 1 (levels == 2) -- confirm intent before changing.
        shift = (offset / half_l).tan()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2  # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        """Map normalized codes to the non-negative range [0, ..., L-1]."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        """Inverse of `_scale_and_shift`: digits in [0, L-1] -> normalized codes."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: Normalized codes of shape (..., d), as returned by `quantize`.

        Returns:
            int32 tensor of shape (...) holding the codebook index of each code.
        """
        assert zhat.shape[-1] == self.codebook_dim
        # Round each digit before the integer cast: a plain cast truncates, so
        # a digit such as 2.9999998 would silently become 2.
        digits = self._scale_and_shift(zhat).round().to(int32)
        # Integer sum promotes to int64 in torch; cast back to int32.
        return (digits * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(
        self,
        indices: Tensor,
        project_out = True
    ) -> Tensor:
        """Inverse of `codes_to_indices`.

        Args:
            indices: Integer tensor of codebook indices, any shape (...).
            project_out: Currently unused; kept so existing callers that pass
                it keep working.

        Returns:
            Normalized codes of shape (..., d).
        """
        # Add a trailing axis so the indices broadcast against the d digits.
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        codes = self._scale_and_shift_inverse(codes_non_centered)
        return codes

# JAX FSQ



In [None]:
import itertools
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
# Type aliases: both are plain `jax.Array`s; the separate names only document
# whether a signature deals with quantized codes or with codebook indices.
Codeword = jax.Array
Indices = jax.Array


def round_ste(z):
  """Round `z` to the nearest integer with a straight-through gradient.

  The forward value equals `jnp.round(z)`; gradients do not flow through the
  rounding residual, so the backward pass sees the identity.
  """
  residual = jnp.round(z) - z
  return z + jax.lax.stop_gradient(residual)


class FSQ:
  """Finite scalar quantizer.

  Every one of the `len(levels)` channels on the last axis is bounded, rounded
  to one of `levels[i]` integer values and divided by `levels[i] // 2`. The
  set of all such codes forms an implicit codebook with `prod(levels)`
  entries; a code is addressed by a mixed-radix integer index in which
  channel 0 is the least significant digit.

  Args:
    levels: Number of quantization levels per channel (each >= 2).
    eps: Relative shrink applied to the half-range used by `bound`.

  Raises:
    ValueError: If `levels` is empty or contains a value below 2.
  """

  def __init__(self, levels: list[int], eps: float = 1e-3):
    self._levels = levels
    self._eps = eps
    self._levels_np = np.asarray(levels)
    # A level of 1 gives a zero half-width and therefore NaN outputs.
    if self._levels_np.size == 0 or (self._levels_np < 2).any():
      raise ValueError(
          f"levels must be a non-empty sequence of ints >= 2, got {levels}")
    # Mixed-radix place values: basis[i] = prod(levels[:i]).
    self._basis = np.concatenate(
        ([1], np.cumprod(self._levels_np[:-1]))).astype(np.uint32)

    self._implicit_codebook = self.indexes_to_codes(
        np.arange(self.codebook_size))

  @property
  def num_dimensions(self) -> int:
    """Number of dimensions expected from inputs."""
    return len(self._levels)

  @property
  def codebook_size(self) -> int:
    """Size of the codebook."""
    # np.prod returns a NumPy scalar; convert so the result is a real int.
    return int(np.prod(self._levels))

  @property
  def codebook(self):
    """Returns the implicit codebook. Shape (prod(levels), num_dimensions)."""
    return self._implicit_codebook

  def bound(self, z: jax.Array) -> jax.Array:
    """Bound `z`, an array of shape (..., d).

    Channels with an odd number of levels are squashed symmetrically around
    zero; channels with an even number of levels are shifted by half a step so
    that rounding yields exactly `levels[i]` values.
    """
    half_l = (self._levels_np - 1) * (1 - self._eps) / 2
    offset = jnp.where(self._levels_np % 2 == 1, 0.0, 0.5)
    shift = jnp.tan(offset / half_l)
    return jnp.tanh(z + shift) * half_l - offset

  def quantize(self, z: jax.Array) -> "Codeword":
    """Quantizes z, returns quantized zhat, same shape as z."""
    quantized = round_ste(self.bound(z))

    # Renormalize to [-1, 1].
    half_width = self._levels_np // 2
    return quantized / half_width

  def _scale_and_shift(self, zhat_normalized):
    """Map normalized codes to the non-negative range [0, ..., L-1]."""
    half_width = self._levels_np // 2
    return (zhat_normalized * half_width) + half_width

  def _scale_and_shift_inverse(self, zhat):
    """Inverse of `_scale_and_shift`: digits in [0, L-1] -> normalized codes."""
    half_width = self._levels_np // 2
    return (zhat - half_width) / half_width

  def codes_to_indexes(self, zhat: "Codeword") -> "Indices":
    """Converts a `code` to an index in the codebook.

    Args:
      zhat: Normalized codes of shape (..., d), as returned by `quantize`.

    Returns:
      uint32 array of shape (...) holding the codebook index of each code.
    """
    assert zhat.shape[-1] == self.num_dimensions
    # Round each digit before the integer cast: a plain cast truncates, so a
    # digit such as 2.9999998 would silently become 2.
    digits = self._scale_and_shift(zhat).round().astype(np.uint32)
    return (digits * self._basis).sum(axis=-1).astype(np.uint32)

  def indexes_to_codes(self, indices: "Indices") -> "Codeword":
    """Inverse of `codes_to_indexes`."""
    # Add a trailing axis so the indices broadcast against the d digits.
    indices = indices[..., jnp.newaxis]
    codes_non_centered = np.mod(
        np.floor_divide(indices, self._basis), self._levels_np
    )
    return self._scale_and_shift_inverse(codes_non_centered)

# Example usage (PyTorch FSQ)

In [None]:
# NOTE: this cell uses the PyTorch API (`torch.tensor`, `codes_to_indices`).
# Both implementations in this notebook are named `FSQ`, so make sure the
# PyTorch class is the one currently bound to that name before running.
fsq = FSQ(levels=[3, 5, 4])

z = torch.tensor([0.25, 0.6, -7])
zhat = fsq.quantize(z)
print(f"Quantized {z} -> {zhat}")

# We can map to an index in the codebook.
idx = fsq.codes_to_indices(zhat)
print(f"Code {zhat} is the {idx}-th index.")

# Back to code: print the decoded value so the round trip is actually shown.
code_out = fsq.indices_to_codes(idx)
print(f"Index {idx} mapped back to {code_out}.")

# Quantizing a multi-dimensional bottleneck

In [None]:
import numpy as np

# Round-trip check on a 4-D input whose last axis has size `codebook_dim`.
# Uses the PyTorch `FSQ` API (`codebook_dim`, `codes_to_indices`).
fsq = FSQ(levels=[5, 4, 3])

d = fsq.codebook_dim
z = torch.rand(3, 8, 8, d)
# Quantization returns a tensor of the same shape as its input.
zhat = fsq.quantize(z)
assert zhat.shape == (3, 8, 8, d)

# Each d-dimensional code collapses into a single index, dropping the last axis.
indices = fsq.codes_to_indices(zhat)
assert indices.shape == (3, 8, 8)

# Decoding the indices must give back codes of the original shape ...
zhat_out = fsq.indices_to_codes(indices)
assert zhat_out.shape == zhat.shape

# ... and with the same values as the quantized input.
np.testing.assert_allclose(zhat.numpy(), zhat_out.numpy())

# Validating codebook

In [None]:
# Print every code of a small 3 x 4 quantizer for visual inspection.
# `implicit_codebook` is the buffer built by the PyTorch `FSQ` constructor.
fsq = FSQ(levels=[3, 4])
print(fsq.implicit_codebook)