```
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# FSQ



In [None]:
import itertools
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
Codeword = jax.Array
Indices = jax.Array


def round_ste(z):
  """Round with straight-through gradients.

  The forward value is `round(z)`. Because the correction `round(z) - z` is
  wrapped in `stop_gradient`, the backward pass sees the identity.
  """
  zhat = jnp.round(z)
  return z + jax.lax.stop_gradient(zhat - z)


class FSQ:
  """Finite Scalar Quantization with an implicit (enumerated) codebook.

  Each of the `len(levels)` input dimensions is bounded and rounded to one of
  `levels[i]` integer values, then rescaled into [-1, 1]. The codebook is the
  Cartesian product of the per-dimension grids, so it has `prod(levels)`
  entries.

  Args:
    levels: Number of quantization levels per dimension.
    eps: Shrinks the bounded range by a factor of `1 - eps`.
  """

  def __init__(self, levels: list[int], eps: float = 1e-3):
    self._levels = levels
    self._eps = eps
    self._levels_np = np.asarray(levels)
    # Mixed-radix place values: index = sum_i digit_i * prod(levels[:i]).
    self._basis = np.concatenate(
        ([1], np.cumprod(self._levels_np[:-1]))).astype(np.uint32)

    # Enumerate every code once; row k is the code whose index is k.
    self._implicit_codebook = self.indexes_to_codes(
        np.arange(self.codebook_size))

  @property
  def num_dimensions(self) -> int:
    """Number of dimensions expected from inputs."""
    return len(self._levels)

  @property
  def codebook_size(self) -> int:
    """Size of the codebook, i.e. `prod(levels)`."""
    # Cast so the value honours the `int` annotation instead of a NumPy scalar.
    return int(np.prod(self._levels))

  @property
  def codebook(self):
    """Returns the implicit codebook. Shape (prod(levels), num_dimensions)."""
    return self._implicit_codebook

  def bound(self, z: jax.Array) -> jax.Array:
    """Bound `z`, an array of shape (..., d).

    Each dimension is squashed with `tanh` into an open interval whose
    rounding yields `levels[i]` integers; even levels are offset by 0.5.
    """
    half_l = (self._levels_np - 1) * (1 - self._eps) / 2
    offset = jnp.where(self._levels_np % 2 == 1, 0.0, 0.5)
    # `shift` appears intended to satisfy tanh(shift) * half_l == offset so
    # that bound(0) == 0 for even levels. The exact solution would be
    # arctanh(offset / half_l); `tan` only approximates it (both are ~x for
    # small x), leaving a small residual at z == 0.
    # NOTE(review): arctanh is undefined for levels == 2 because
    # offset / half_l > 1 when eps > 0, so do not switch blindly.
    shift = jnp.tan(offset / half_l)
    return jnp.tanh(z + shift) * half_l - offset

  def quantize(self, z: jax.Array) -> Codeword:
    """Quantizes z, returns quantized zhat, same shape as z.

    Components of the result lie on the grid `k / (levels[i] // 2)` for
    integer k, i.e. within [-1, 1]. Gradients pass straight through rounding.
    """
    quantized = round_ste(self.bound(z))

    # Renormalize to [-1, 1].
    half_width = self._levels_np // 2
    return quantized / half_width

  def _scale_and_shift(self, zhat_normalized):
    """Maps normalized codes to per-dimension digits in [0, L-1]."""
    half_width = self._levels_np // 2
    return (zhat_normalized * half_width) + half_width

  def _scale_and_shift_inverse(self, zhat):
    """Maps digits in [0, L-1] back to normalized codes."""
    half_width = self._levels_np // 2
    return (zhat - half_width) / half_width

  def codes_to_indexes(self, zhat: Codeword) -> Indices:
    """Converts a `code` to an index in the codebook.

    Args:
      zhat: Codes of shape (..., num_dimensions) on the normalized grid, e.g.
        the output of `quantize`.

    Returns:
      Unsigned 32-bit indices of shape (...).
    """
    assert zhat.shape[-1] == self.num_dimensions, (
        f"expected last dim {self.num_dimensions}, got {zhat.shape[-1]}")
    # Round before the integer cast: `astype` truncates, so a digit computed
    # as 2.9999998 would otherwise become 2 and select the wrong code.
    # `.round()` keeps NumPy inputs NumPy and JAX inputs JAX.
    digits = self._scale_and_shift(zhat).round()
    return (digits * self._basis).sum(axis=-1).astype(np.uint32)

  def indexes_to_codes(self, indices: Indices) -> Codeword:
    """Inverse of `codes_to_indexes`.

    Args:
      indices: Integer codebook indices of shape (...).

    Returns:
      Codes of shape (..., num_dimensions) with components in [-1, 1].
    """
    indices = indices[..., jnp.newaxis]
    # Mixed-radix decomposition: digit_i = (index // basis_i) % levels_i.
    # Uses NumPy ops, so the result is a NumPy array (presumably not
    # traceable under `jax.jit` -- confirm if that is needed).
    codes_non_centered = np.mod(
        np.floor_divide(indices, self._basis), self._levels_np
    )
    return self._scale_and_shift_inverse(codes_non_centered)

# Example usage

In [None]:
fsq = FSQ(levels=[3, 5, 4])

# One input vector with one entry per quantized dimension.
z = np.asarray([0.25, 0.6, -7])
print(z.shape)
zhat = fsq.quantize(z)
print(f"Quantized {z} -> {zhat}")

# We can map to an index in the codebook.
idx = fsq.codes_to_indexes(zhat)
print(f"Code {zhat} is the {idx}-th index.")

# Back to code: decoding the index must reproduce the quantized code.
code_out = fsq.indexes_to_codes(idx)
print(f"Index {idx} mapped back to {code_out}.")
np.testing.assert_allclose(zhat, code_out)

# Quantizing a multi-dimensional bottleneck

In [None]:
fsq = FSQ(levels=[5, 4, 3])

# Number of quantized channels; the input's last axis must have this size.
d = fsq.num_dimensions
print(d)
# Random bottleneck features, uniform in [0, 1), shaped (3, 8, 8, d).
z = np.random.uniform(size=(3, 8, 8, d))
zhat = fsq.quantize(z)
assert zhat.shape == (3, 8, 8, d)

# One codebook index per position: the last (code) axis is reduced away.
indices = fsq.codes_to_indexes(zhat)
assert indices.shape == (3, 8, 8)

zhat_out = fsq.indexes_to_codes(indices)
assert zhat_out.shape == zhat.shape

# Round trip: decoding the indices must reproduce the quantized codes.
np.testing.assert_allclose(zhat, zhat_out)

# Inspecting the implicit codebook

In [None]:
fsq = FSQ(levels=[3, 4])
# Implicit codebook: prod(levels) = 12 rows, one per code, each with 2 entries.
print(fsq.codebook)

In [None]:
import torch
from vector_quantize_pytorch import VectorQuantize
from transformers import CLIPProcessor, CLIPModel

In [None]:
# Pretrained CLIP ViT-B/32; only its image tower is used below (get_image_features).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# If preprocessing is re-enabled, load the processor from the same checkpoint as the model.
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# Random stand-in observations: (batch=2, time=10, channels=3, H, W).
# CLIP ViT-B/32 is configured for 224x224 inputs; 244x244 was a typo that
# recent `transformers` releases reject with a size-mismatch error.
obs = torch.randn(2, 10, 3, 224, 224)
# Fold batch and time together so CLIP sees an ordinary image batch.
obs = obs.view(-1, 3, 224, 224)
obs_embed = model.get_image_features(obs)
# Restore the (batch, time) split; the last axis is the CLIP embedding size.
obs_embed = obs_embed.view(2, 10, -1)

In [None]:
# Expect (2, 10, embed_dim); the encoders below assume embed_dim == 512 -- confirm.
obs_embed.shape

In [None]:
import torch.nn as nn

# Assume you have image embeddings with shape [2, 10, 512]
# image_embeddings = torch.rand((2, 10, 512))

# Define the Transformer Encoder model
class TransformerEncoderModel(nn.Module):
    """Stack of self-attention encoder layers over (batch, seq, feature) input.

    Args:
        input_size: Feature size of each token (`d_model`); must be divisible
            by the 2 attention heads.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_size, num_layers):
        super().__init__()
        block = nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=2,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(block, num_layers=num_layers)

    def forward(self, x):
        """Return the encoded sequence, shaped like `x`."""
        return self.transformer_encoder(x)

# Define the Transformer Encoder model with the specified input size and number of layers
# input_size must equal the last dim of obs_embed (the CLIP embedding size, presumably 512 -- confirm).
transformer_model = TransformerEncoderModel(input_size=512, num_layers=2)

# Forward pass through the Transformer Encoder with the image embeddings
# The encoder keeps the input shape, so encoder_output is (batch, time, 512).
encoder_output = transformer_model(obs_embed)

# Print the output shape
print("Transformer Encoder output shape:", encoder_output.shape)

In [None]:
# Flatten each sequence into one vector: (2, 10, 512) -> (2, 5120).
# `view` requires these hard-coded sizes to match encoder_output exactly.
x = encoder_output.view(2, 10*512)

In [None]:
x.shape

In [None]:
# Project each flattened sequence to 4 features: (2, 5120) -> (2, 4).
prediction_head = nn.Linear(10*512, 4)
out = prediction_head(x)

In [None]:
out.shape

In [None]:
# Prepend a leading axis: (2, 4) -> (1, 2, 4), so it can be fed to the FSQ quantizer as (batch, seq, dim).
out = out.unsqueeze(0)

In [None]:
from vector_quantize_pytorch import FSQ as TorchFSQ

# Smoke-test the PyTorch FSQ on random inputs. Build the quantizer here so this
# cell does not depend on `quantizer` from a cell that runs later.
levels = [8, 5, 5, 5]
zt = torch.randn(1, 2, 4)  # (batch, seq, dim); dim presumably must equal len(levels)
quantizer = TorchFSQ(levels)
quantizer(zt)

In [None]:
# Alias the import so it does not shadow the JAX `FSQ` class defined earlier in
# this notebook (re-running those cells would otherwise use the PyTorch class).
from vector_quantize_pytorch import FSQ as TorchFSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = TorchFSQ(levels)
# `out` is (1, 2, 4); its last dim presumably must equal len(levels).
xhat, indices = quantizer(out)

In [None]:
import torch.nn as nn

# Assume you have image embeddings with shape [2, 10, 512]
# image_embeddings = torch.rand((2, 10, 512))

# Define the Transformer Encoder model
class TransformerEncoderModel(nn.Module):
    """Transformer encoder followed by a per-token two-layer MLP head.

    Maps (batch, seq, input_size) to (batch, seq, output_size).

    Args:
        input_size: Token feature size (`d_model`); must be divisible by the
            2 attention heads.
        hidden_size: Width of the MLP hidden layer.
        output_size: Feature size of each output token.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=input_size, nhead=2, batch_first=True),
            num_layers=num_layers
        )
        self.linear1 = nn.Linear(input_size, hidden_size)
        # Without a non-linearity, linear2(linear1(x)) collapses into a single
        # affine map and the hidden layer adds nothing. ReLU matches the
        # encoder layer's default activation and has no parameters, so the
        # state_dict keys are unchanged.
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Contextualize tokens: (batch, seq, input_size) -> same shape.
        x = self.transformer_encoder(x)

        # Per-token MLP head: input_size -> hidden_size -> output_size.
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        return x

# Define the Transformer Encoder model with the specified input size, hidden size, output size, and number of layers
# input_size must equal the last dim of obs_embed; output_size=4 matches len(levels) of the FSQ quantizer used next.
transformer_model = TransformerEncoderModel(input_size=512, hidden_size=256, output_size=4, num_layers=2)

# Forward pass through the Transformer Encoder with the image embeddings
# Rebinds the global `encoder_output`: it is now (batch, time, 4), not (batch, time, 512).
encoder_output = transformer_model(obs_embed)

# Print the output shape
print("Transformer Encoder output shape:", encoder_output.shape)

In [None]:
# Alias the import so it does not shadow the JAX `FSQ` class defined earlier in
# this notebook.
from vector_quantize_pytorch import FSQ as TorchFSQ

levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
quantizer = TorchFSQ(levels)

# encoder_output is (batch, seq, 4); the 4 matches len(levels), one entry per quantized dimension.
xhat, indices = quantizer(encoder_output)

In [None]:
print(xhat.shape)    # (batch, seq, dim); presumably (2, 10, 4) for the encoder_output above
print(indices.shape)  # presumably (batch, seq): one codebook index per token -- confirm

In [None]:
xhat

In [None]:
indices

In [None]:
# The codebook dimension must match the features being quantized. At this
# point `encoder_output` comes from the encoder-plus-MLP model, so its last
# dim is 4; a hard-coded 512 would not match it.
vq = VectorQuantize(
    dim = encoder_output.shape[-1],
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 1.   # the weight on the commitment loss
)

In [None]:
# Presumably returns (quantized features, code indices, commitment loss) -- confirm against vector_quantize_pytorch docs.
quantized, indices, commit_loss = vq(encoder_output)

In [None]:
indices

In [None]:
quantized.shape

In [None]:
import torch
import torch.nn as nn

# Random stand-in features shaped (batch=2, seq=10, features=512).
input_tensor = torch.rand(2, 10, 512)

# NOTE(review): despite the "FiLM" label, this is a plain per-token linear
# projection 512 -> 4. FiLM would predict a scale and shift from a separate
# conditioning input and apply gamma * x + beta. Rename or implement FiLM.
film_layer = nn.Linear(in_features=512, out_features=4)
output_tensor = film_layer(input_tensor)

print("Output tensor shape (FiLM):", output_tensor.shape)


In [None]:
!pip install vector-quantize-pytorch