<a href="https://colab.research.google.com/github/Ryan0v0/nninn/blob/master/vq_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!pip3 install -U -r requirements.txt

# Step1: Splitting up neural net params into chunks

In [None]:
import torch
import torch.nn as nn

# Define the neural network architecture
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(10, 100)
        self.fc2 = nn.Linear(100, 100)
        self.fc3 = nn.Linear(100, 1)

        # Initialize the weights to be non-negative
        nn.init.uniform_(self.fc1.weight, a=0, b=1)
        nn.init.uniform_(self.fc2.weight, a=0, b=1)
        nn.init.uniform_(self.fc3.weight, a=0, b=1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create an instance of the neural network
net = NeuralNetwork()

# Split up the neural network parameters into chunks
chunk_size = 1000
param_chunks = []
for param in net.parameters():
    flattened_param = param.view(-1)
    chunks = torch.split(flattened_param, chunk_size)
    param_chunks.extend(chunks)

# Print the number of parameter chunks
print("Number of parameter chunks:", len(param_chunks))
print("Parameter chunks:", param_chunks)


Number of parameter chunks: 15
Parameter chunks: [tensor([0.2919, 0.2992, 0.8227, 0.0579, 0.0281, 0.1178, 0.8750, 0.7092, 0.2448,
        0.7001, 0.9529, 0.3145, 0.3038, 0.7455, 0.4266, 0.1648, 0.6162, 0.3966,
        0.3237, 0.0941, 0.2232, 0.4404, 0.9375, 0.6002, 0.1611, 0.3535, 0.3681,
        0.3448, 0.3686, 0.5276, 0.0393, 0.6725, 0.8932, 0.7149, 0.4007, 0.5334,
        0.7816, 0.3900, 0.7587, 0.3536, 0.8926, 0.5112, 0.1222, 0.7303, 0.3731,
        0.6274, 0.3578, 0.0065, 0.5690, 0.8900, 0.5028, 0.2401, 0.5878, 0.7026,
        0.4572, 0.6603, 0.5703, 0.3134, 0.5029, 0.2670, 0.5947, 0.3110, 0.8365,
        0.5826, 0.0481, 0.0697, 0.2296, 0.9035, 0.8822, 0.8258, 0.0112, 0.4619,
        0.4472, 0.0317, 0.9144, 0.7091, 0.2712, 0.8672, 0.4228, 0.6326, 0.4163,
        0.7031, 0.6291, 0.8241, 0.0417, 0.8467, 0.2199, 0.4562, 0.0821, 0.6900,
        0.3033, 0.3546, 0.0202, 0.2633, 0.1026, 0.5048, 0.9257, 0.1560, 0.8167,
        0.2368, 0.7459, 0.6086, 0.9588, 0.9297, 0.2786, 0.4519, 0.3187

In [None]:
import numpy as np

# Convert param_chunks to a numpy array
param_chunks_np = np.concatenate([chunk.detach().numpy() for chunk in param_chunks])

print(type(param_chunks_np))
print("size:", param_chunks_np.shape)

<class 'numpy.ndarray'>
size: (11301,)


# Step2: learning a mapping from each chunk to an integer via VQ-VAE

In [None]:
from __future__ import print_function


import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter


from six.moves import xrange

# import umap

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

using the weights of the above neural network as input.

In [None]:
data_variance = np.var(param_chunks_np / 255.0)

print(data_variance)

1.3333508e-06


## Vector Quantizer Layer

This layer takes a tensor to be quantized. The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.

The output tensor will have the same shape as the input.

As an example for a `BCHW` tensor of shape `[16, 64, 32, 32]`, we will first convert it to an `BHWC` tensor of shape `[16, 32, 32, 64]` and then reshape it into `[16384, 64]` and all `16384` vectors of size `64`  will be quantized independently. In otherwords, the channels are used as the space in which to quantize. All other dimensions will be flattened and be seen as different examples to quantize, `16384` in this case.

In [None]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        '''
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        '''
        # convert inputs from HW -> HW
        inputs = inputs.permute(0, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 1).contiguous(), perplexity, encodings # (0, 3, 1, 2)

We will also implement a slightly modified version  which will use exponential moving averages to update the embedding vectors instead of an auxillary loss. This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

In [None]:
class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 1).contiguous() # (0, 2, 3, 1)
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 1).contiguous(), perplexity, encodings # (0, 3, 1, 2)

## Encoder & Decoder Architecture

The encoder and decoder architecture is based on a ResNet and is implemented below:

In [None]:
'''
class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Linear(in_features=in_channels,
                      out_features=num_residual_hiddens),
            nn.ReLU(True),
            nn.Linear(in_features=num_residual_hiddens,
                      out_features=num_hiddens) # Replaced with nn.Linear
        )

    def forward(self, x):
        return x + self._block(x)
        # x = x.unsqueeze(2)  # Add an extra dimension
        # output = x + self._block(x)
        # return output.squeeze(2)  # Remove the extra dimension
'''

class Residual(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv1d(in_channels=in_channels, out_channels=num_residual_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        x = x.unsqueeze(2)  # Add an extra dimension
        output = x + self._block(x)
        return output.squeeze(2)  # Remove the extra dimension

class ResidualStack(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                             for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)

In [None]:
class Encoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()

        self._linear_1 = nn.Linear(in_channels*64, num_hiddens//2)
        self._linear_2 = nn.Linear(num_hiddens//2, num_hiddens)


        self._residual_stack = ResidualStack(in_channels=128, #num_hiddens,
                                             num_hiddens=1,
                                             num_residual_layers=1,
                                             num_residual_hiddens=1)


    def forward(self, inputs):
        x = self._linear_1(inputs.view(inputs.size(0), -1))
        print("x=", x.shape)
        x = F.relu(x)

        x = self._linear_2(x)
        x = F.relu(x)

        return self._residual_stack(x)

In [None]:
class Decoder(nn.Module):
    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        self._linear_1 = nn.Linear(in_channels, num_hiddens)

        self._residual_stack = ResidualStack(in_channels=128, #num_hiddens,
                                             num_hiddens=1,
                                             num_residual_layers=1,
                                             num_residual_hiddens=1)

        self._linear_2 = nn.Linear(num_hiddens, num_hiddens//2)

    def forward(self, inputs):
        x = self._linear_1(inputs.view(inputs.size(0), -1))
        print("x=", x.shape)
        x = self._residual_stack(x)

        x = self._linear_2(x)
        x = F.relu(x)

        return x

## Train

We use the hyperparameters from the author's code:

In [None]:
batch_size = 256
num_training_updates = 15000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25

decay = 0.99

learning_rate = 1e-3

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Create a TensorDataset from param_chunks_np
dataset = TensorDataset(torch.from_numpy(param_chunks_np))

# Set the batch size and other DataLoader parameters
batch_size = 64
shuffle = True
pin_memory = True

# Create the DataLoader using the custom dataset
training_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=pin_memory)

print(training_loader)

print(iter(training_loader))

data = next(iter(training_loader))

print("ori_data=", data)

print(type(next(iter(training_loader))))
# print(data[0])

data = next(iter(training_loader))
# for i in range(len(data)):
#    data[i] = data[i].to(device)
data = torch.stack(data).to(device)
print("data=", data)

# There's no label in the NN weight dataset
'''
for batch_idx, data in enumerate(training_loader):
    print("Batch Index:", batch_idx)
    print("Data:", data)
    print()
'''

<torch.utils.data.dataloader.DataLoader object at 0x7af8352ffaf0>
<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x7af8352fc0a0>
ori_data= [tensor([0.4888, 0.0474, 0.2346, 0.6123, 0.8177, 0.4739, 0.9732, 0.6205, 0.1436,
        0.0986, 0.5187, 0.0252, 0.7701, 0.9893, 0.9328, 0.4948, 0.3099, 0.2192,
        0.5352, 0.1977, 0.8090, 0.8964, 0.4286, 0.2837, 0.0108, 0.5531, 0.9424,
        0.3177, 0.3435, 0.4505, 0.4231, 0.5755, 0.7621, 0.2150, 0.4349, 0.4824,
        0.1983, 0.1066, 0.3323, 0.1205, 0.9500, 0.1452, 0.6512, 0.8441, 0.4323,
        0.8228, 0.2539, 0.3342, 0.7360, 0.8981, 0.3812, 0.7564, 0.3739, 0.6766,
        0.5419, 0.0210, 0.5698, 0.5816, 0.9289, 0.0235, 0.7799, 0.8610, 0.0456,
        0.4263])]
<class 'list'>
data= tensor([[ 0.0951,  0.6299,  0.4237,  0.3078,  0.7555, -0.1972,  0.6487,  0.7891,
          0.2252,  0.3612,  0.5262,  0.2401,  0.6173,  0.0974,  0.2855,  0.2621,
          0.8607,  0.5508,  0.4699,  0.1263,  0.2584,  0.4404,  0.1232,  0.316

'\nfor batch_idx, data in enumerate(training_loader):\n    print("Batch Index:", batch_idx)\n    print("Data:", data)\n    print()\n'

In [None]:
cifar_training_data = datasets.CIFAR10(root="data", train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5,0.5,0.5), (1.0,1.0,1.0))
                                  ]))

cifar_training_loader = DataLoader(cifar_training_data,
                             batch_size=batch_size,
                             shuffle=True,
                             pin_memory=True)

print(cifar_training_loader)

print(iter(cifar_training_loader))

(data, _) = next(iter(cifar_training_loader))
print(type(next(iter(cifar_training_loader))))
print(type(data))
print(type(_))
print("ori_data=", data)

data = data.to(device)

print("data=", data)

'''
for batch_idx, (data, _) in enumerate(cifar_training_loader):
    print("Batch Index:", batch_idx)
    # print("Data:", data)
    print("Label:", _)
    print()
'''

Files already downloaded and verified
<torch.utils.data.dataloader.DataLoader object at 0x7af835315120>
<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x7af83fb9af50>
<class 'list'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
ori_data= tensor([[[[-0.1863, -0.1824, -0.1706,  ..., -0.1863, -0.1980, -0.2098],
          [-0.1627, -0.1549, -0.1431,  ..., -0.1745, -0.1863, -0.1980],
          [-0.1510, -0.1431, -0.1314,  ..., -0.1667, -0.1784, -0.1863],
          ...,
          [-0.1471, -0.1431, -0.1353,  ..., -0.1471, -0.1588, -0.1706],
          [-0.1353, -0.1314, -0.1275,  ..., -0.1510, -0.1588, -0.1706],
          [-0.1510, -0.1431, -0.1275,  ..., -0.1510, -0.1588, -0.1667]],

         [[-0.0647, -0.0569, -0.0451,  ..., -0.0843, -0.0961, -0.1078],
          [-0.0490, -0.0412, -0.0255,  ..., -0.0725, -0.0843, -0.0961],
          [-0.0412, -0.0333, -0.0216,  ..., -0.0647, -0.0765, -0.0843],
          ...,
          [-0.0294, -0.0255, -0.0176,  ..., -0.0529, -0.0647,

'\nfor batch_idx, (data, _) in enumerate(cifar_training_loader):\n    print("Batch Index:", batch_idx)\n    # print("Data:", data)\n    print("Label:", _)\n    print()\n'

In [None]:
class Model(nn.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()

        # Updated input size to [1, 64]
        self._encoder = Encoder(1, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        print(self._encoder)

        # Replaced with nn.Linear
        self._pre_vq_linear = nn.Linear(num_hiddens, embedding_dim)

        print(self._pre_vq_linear)

        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)

        print(self._vq_vae)

        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        print(self._decoder)

    def forward(self, x):
        # x = x.unsqueeze(0)  # Add an extra dimension for batch size
        ## print("x=", x)
        ## print(x.shape)
        z = self._encoder(x)
        z = z.view(z.size(0), -1)  # Flatten the tensor
        ## print("z0=", z)
        # Replaced self._pre_vq_conv with self._pre_vq_linear
        z = self._pre_vq_linear(z)
        ## print("z1=", z)

        loss, quantized, perplexity, _ = self._vq_vae(z)
        ## print("quantized:", quantized)

        # # Reshape quantized before passing it to the decoder
        # quantized = quantized.unsqueeze(2).unsqueeze(3)
        # print("quantized_after:", quantized)
        x_recon = self._decoder(quantized)

        # x_recon = 0
        return loss, x_recon, perplexity

In [None]:
model = Model(num_hiddens, num_residual_layers, num_residual_hiddens,
              num_embeddings, embedding_dim,
              commitment_cost, decay).to(device)

Encoder(
  (_linear_1): Linear(in_features=64, out_features=64, bias=True)
  (_linear_2): Linear(in_features=64, out_features=128, bias=True)
  (_residual_stack): ResidualStack(
    (_layers): ModuleList(
      (0): Residual(
        (_block): Sequential(
          (0): ReLU(inplace=True)
          (1): Conv1d(128, 1, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
  )
)
Linear(in_features=128, out_features=64, bias=True)
VectorQuantizerEMA(
  (_embedding): Embedding(512, 64)
)
Decoder(
  (_linear_1): Linear(in_features=64, out_features=128, bias=True)
  (_residual_stack): ResidualStack(
    (_layers): ModuleList(
      (0): Residual(
        (_block): Sequential(
          (0): ReLU(inplace=True)
          (1): Conv1d(128, 1, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
  )
  (_linear_2): Linear(in_features=128, out_features=64, bias=True)
)


In [None]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)

In [None]:
model.train()
train_res_recon_error = []
train_res_perplexity = []

for i in xrange(num_training_updates):
    # (data, _) = next(iter(training_loader))
    # data = data.to(device)
    data = next(iter(training_loader))
    data = torch.stack(data).to(device)
    #for i in range(len(data)):
    #  data[i] = data[i].to(device)
    optimizer.zero_grad()

    print(data.shape)
    vq_loss, data_recon, perplexity = model(data)
    recon_error = F.mse_loss(data_recon, data) / data_variance
    torch.autograd.set_detect_anomaly(True) # for debug
    loss = recon_error + vq_loss
    loss.backward()

    optimizer.step()

    train_res_recon_error.append(recon_error.item())
    train_res_perplexity.append(perplexity.item())

    if (i+1) % 100 == 0:
        print('%d iterations' % (i+1))
        print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
        % print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))
        print()

torch.Size([1, 64])
x= torch.Size([1, 64])
x= torch.Size([1, 128])


  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 600, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1896, in _run_once
    handle._run()
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callb

RuntimeError: ignored


*The reason of the error:*

When using the ReLU activation function, if the input of a neuron is negative, it causes the neuron to output zero constantly, resulting in deactivation. Since the gradient is zero at this point, it cannot recover.



## Plot Loss

In [None]:
train_res_recon_error_smooth = savgol_filter(train_res_recon_error, 201, 7)
train_res_perplexity_smooth = savgol_filter(train_res_perplexity, 201, 7)

In [None]:
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)

ax.plot(train_res_recon_error_smooth)
ax.set_yscale('log')
ax.set_title('Smoothed NMSE.')
ax.set_xlabel('iteration')
'''
ax = f.add_subplot(1,2,2)
ax.plot(train_res_perplexity_smooth)
ax.set_title('Smoothed Average codebook usage (perplexity).')
ax.set_xlabel('iteration')
'''

## View Reconstructions

In [None]:
model.eval()

data = next(iter(training_loader))
train_originals = torch.stack(data).to(device)
vq_output_eval = model._pre_vq_linear(model._encoder(train_originals))
print("vq_output_eval=", vq_output_eval)
print("vq_output_eval.shape=", vq_output_eval.shape)
_, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
print(valid_quantize)
print("valid_quantize=", valid_quantize)
print("valid_quantize.shape=", valid_quantize.shape)
# (train_originals, _) = next(iter(training_loader))
# train_originals = train_originals.to(device)
valid_reconstructions = model._decoder(valid_quantize)

In [None]:
def show(img):
    npimg = img.numpy()
    fig = plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)

In [None]:
show(make_grid(valid_reconstructions.cpu().data)+0.5, )

In [None]:
show(make_grid(train_originals.cpu()+0.5))

## View Embedding

In [None]:
! pip uninstall umap
! pip install umap-learn

import umap.umap_ as umap

proj = umap.UMAP(n_neighbors=3,
                 min_dist=0.1,
                 metric='cosine').fit_transform(model._vq_vae._embedding.weight.data.cpu())

In [None]:
plt.scatter(proj[:,0], proj[:,1], alpha=0.3)

# JAX version

In [8]:
!pip install dm-haiku==0.0.9

Collecting dm-haiku==0.0.9
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Collecting jmp>=0.0.2 (from dm-haiku==0.0.9)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.9 jmp-0.0.4


In [9]:
"""Haiku implementation of VQ-VAE https://arxiv.org/abs/1711.00937."""

from typing import Any, Optional

from haiku._src import base
from haiku._src import initializers
from haiku._src import module
from haiku._src import moving_averages

import jax
import jax.numpy as jnp


# If you are forking replace this with `import haiku as hk`.
# pylint: disable=invalid-name
class hk:
  get_parameter = base.get_parameter
  get_state = base.get_state
  set_state = base.set_state
  initializers = initializers
  ExponentialMovingAverage = moving_averages.ExponentialMovingAverage
  Module = module.Module
# pylint: enable=invalid-name
del base, initializers, module, moving_averages


class VectorQuantizer(hk.Module):
  """Haiku module representing the VQ-VAE layer.

  Implements the algorithm presented in
  "Neural Discrete Representation Learning" by van den Oord et al.
  https://arxiv.org/abs/1711.00937

  Input any tensor to be quantized. Last dimension will be used as space in
  which to quantize. All other dimensions will be flattened and will be seen
  as different examples to quantize.

  The output tensor will have the same shape as the input.

  For example a tensor with shape ``[16, 32, 32, 64]`` will be reshaped into
  ``[16384, 64]`` and all ``16384`` vectors (each of ``64`` dimensions)  will be
  quantized independently.

  Attributes:
    embedding_dim: integer representing the dimensionality of the tensors in the
      quantized space. Inputs to the modules must be in this format as well.
    num_embeddings: integer, the number of vectors in the quantized space.
    commitment_cost: scalar which controls the weighting of the loss terms (see
      equation 4 in the paper - this variable is Beta).
  """

  def __init__(
      self,
      embedding_dim: int,
      num_embeddings: int,
      commitment_cost: float,
      dtype: Any = jnp.float32,
      name: Optional[str] = None,
      cross_replica_axis: Optional[str] = None,
  ):
    """Initializes a VQ-VAE module.

    Args:
      embedding_dim: dimensionality of the tensors in the quantized space.
        Inputs to the modules must be in this format as well.
      num_embeddings: number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      dtype: dtype for the embeddings variable, defaults to ``float32``.
      name: name of the module.
      cross_replica_axis: If not ``None``, it should be a string representing
        the axis name over which this module is being run within a
        :func:`jax.pmap`. Supplying this argument means that perplexity is
        calculated across all replicas on that axis.
    """
    super().__init__(name=name)
    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    self.commitment_cost = commitment_cost
    self.cross_replica_axis = cross_replica_axis

    self._embedding_shape = [embedding_dim, num_embeddings]
    self._embedding_dtype = dtype

  @property
  def embeddings(self):
    initializer = hk.initializers.VarianceScaling(distribution="uniform")
    return hk.get_parameter(
        "embeddings",
        self._embedding_shape,
        self._embedding_dtype,
        init=initializer)

  def __call__(self, inputs, is_training):
    """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to ``embedding_dim``. All
        other leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data.

    Returns:
      dict: Dictionary containing the following keys and values:
        * ``quantize``: Tensor containing the quantized version of the input.
        * ``loss``: Tensor containing the loss to optimize.
        * ``perplexity``: Tensor containing the perplexity of the encodings.
        * ``encodings``: Tensor containing the discrete encodings, ie which
          element of the quantized space each input element was mapped to.
        * ``encoding_indices``: Tensor containing the discrete encoding indices,
          ie which element of the quantized space each input element was mapped
          to.
    """
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])

    distances = (
        jnp.sum(jnp.square(flat_inputs), 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, self.embeddings) +
        jnp.sum(jnp.square(self.embeddings), 0, keepdims=True))

    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    # NB: if your code crashes with a reshape error on the line below about a
    # Tensor containing the wrong number of values, then the most likely cause
    # is that the input passed in does not have a final dimension equal to
    # self.embedding_dim. Ideally we would catch this with an Assert but that
    # creates various other problems related to device placement / TPUs.
    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)

    e_latent_loss = jnp.mean(
        jnp.square(jax.lax.stop_gradient(quantized) - inputs))
    q_latent_loss = jnp.mean(
        jnp.square(quantized - jax.lax.stop_gradient(inputs)))
    loss = q_latent_loss + self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

  def quantize(self, encoding_indices):
    """Returns embedding tensor for a batch of indices."""
    w = self.embeddings.swapaxes(1, 0)
    w = jax.device_put(w)  # Required when embeddings is a NumPy array.
    return w[(encoding_indices,)]


class VectorQuantizerEMA(hk.Module):
  r"""Haiku module representing the VQ-VAE layer.

  Implements a slightly modified version of the algorithm presented in
  "Neural Discrete Representation Learning" by van den Oord et al.
  https://arxiv.org/abs/1711.00937

  The difference between :class:`VectorQuantizerEMA` and
  :class:`VectorQuantizer` is that this module uses
  :class:`~haiku.ExponentialMovingAverage`\ s to update the embedding vectors
  instead of an auxiliary loss. This has the advantage that the embedding
  updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-Fac,
  ...) used for the encoder, decoder and other parts of the architecture. For
  most experiments the EMA version trains faster than the non-EMA version.

  Input any tensor to be quantized. Last dimension will be used as space in
  which to quantize. All other dimensions will be flattened and will be seen
  as different examples to quantize.

  The output tensor will have the same shape as the input.

  For example a tensor with shape ``[16, 32, 32, 64]`` will be reshaped into
  ``[16384, 64]`` and all ``16384`` vectors (each of 64 dimensions)  will be
  quantized independently.

  Attributes:
    embedding_dim: integer representing the dimensionality of the tensors in
      the quantized space. Inputs to the modules must be in this format as well.
    num_embeddings: integer, the number of vectors in the quantized space.
    commitment_cost: scalar which controls the weighting of the loss terms
      (see equation 4 in the paper).
    decay: float, decay for the moving averages.
    epsilon: small float constant to avoid numerical instability.
  """

  def __init__(
      self,
      embedding_dim,
      num_embeddings,
      commitment_cost,
      decay,
      epsilon: float = 1e-5,
      dtype: Any = jnp.float32,
      cross_replica_axis: Optional[str] = None,
      name: Optional[str] = None,
  ):
    """Initializes a VQ-VAE EMA module.

    Args:
      embedding_dim: integer representing the dimensionality of the tensors in
        the quantized space. Inputs to the modules must be in this format as
        well.
      num_embeddings: integer, the number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      decay: float between 0 and 1, controls the speed of the Exponential Moving
        Averages.
      epsilon: small constant to aid numerical stability, default ``1e-5``.
      dtype: dtype for the embeddings variable, defaults to ``float32``.
      cross_replica_axis: If not ``None``, it should be a string representing
        the axis name over which this module is being run within a
        :func:`jax.pmap`. Supplying this argument means that cluster statistics
        and the perplexity are calculated across all replicas on that axis.
      name: name of the module.
    """
    super().__init__(name=name)
    if not 0 <= decay <= 1:
      raise ValueError("decay must be in range [0, 1]")

    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    self.decay = decay
    self.commitment_cost = commitment_cost
    self.epsilon = epsilon
    self.cross_replica_axis = cross_replica_axis

    self._embedding_shape = [embedding_dim, num_embeddings]
    self._dtype = dtype

    self._ema_cluster_size = hk.ExponentialMovingAverage(
        decay=self.decay, name="ema_cluster_size")
    self._ema_dw = hk.ExponentialMovingAverage(decay=self.decay, name="ema_dw")

  @property
  def embeddings(self):
    initializer = hk.initializers.VarianceScaling(distribution="uniform")
    return hk.get_state(
        "embeddings", self._embedding_shape, self._dtype, init=initializer)

  @property
  def ema_cluster_size(self):
    self._ema_cluster_size.initialize([self.num_embeddings], self._dtype)
    return self._ema_cluster_size

  @property
  def ema_dw(self):
    self._ema_dw.initialize(self._embedding_shape, self._dtype)
    return self._ema_dw

  def __call__(self, inputs, is_training):
    """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to ``embedding_dim``. All
        other leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data. When
        this is set to ``False``, the internal moving average statistics will
        not be updated.

    Returns:
      dict: Dictionary containing the following keys and values:
        * ``quantize``: Tensor containing the quantized version of the input.
        * ``loss``: Tensor containing the loss to optimize.
        * ``perplexity``: Tensor containing the perplexity of the encodings.
        * ``encodings``: Tensor containing the discrete encodings, ie which
          element of the quantized space each input element was mapped to.
        * ``encoding_indices``: Tensor containing the discrete encoding indices,
          ie which element of the quantized space each input element was mapped
          to.
    """
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    embeddings = self.embeddings

    distances = (
        jnp.sum(jnp.square(flat_inputs), 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, embeddings) +
        jnp.sum(jnp.square(embeddings), 0, keepdims=True))

    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    # NB: if your code crashes with a reshape error on the line below about a
    # Tensor containing the wrong number of values, then the most likely cause
    # is that the input passed in does not have a final dimension equal to
    # self.embedding_dim. Ideally we would catch this with an Assert but that
    # creates various other problems related to device placement / TPUs.
    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)
    e_latent_loss = jnp.mean(
        jnp.square(jax.lax.stop_gradient(quantized) - inputs))

    if is_training:
      cluster_size = jnp.sum(encodings, axis=0)
      if self.cross_replica_axis:
        cluster_size = jax.lax.psum(
            cluster_size, axis_name=self.cross_replica_axis)
      updated_ema_cluster_size = self.ema_cluster_size(cluster_size)

      dw = jnp.matmul(flat_inputs.T, encodings)
      if self.cross_replica_axis:
        dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
      updated_ema_dw = self.ema_dw(dw)

      n = jnp.sum(updated_ema_cluster_size)
      updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                  (n + self.num_embeddings * self.epsilon) * n)

      normalised_updated_ema_w = (
          updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

      hk.set_state("embeddings", normalised_updated_ema_w)
      loss = self.commitment_cost * e_latent_loss

    else:
      loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

  def quantize(self, encoding_indices):
    """Returns embedding tensor for a batch of indices."""
    w = self.embeddings.swapaxes(1, 0)
    w = jax.device_put(w)  # Required when embeddings is a NumPy array.
    return w[(encoding_indices,)]

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn

class Residual(nn.Module):
    in_channels: int
    num_hiddens: int
    num_residual_hiddens: int

    def setup(self):
        self._block = nn.Sequential([
            nn.relu,
            nn.Conv(features=self.num_residual_hiddens,
                    kernel_size=(1,),
                    strides=(1,),
                    use_bias=False)
        ])

    def __call__(self, x):
        x = jnp.expand_dims(x, 2)  # Add an extra dimension
        output = x + self._block(x)
        return jnp.squeeze(output, 2)  # Remove the extra dimension

class ResidualStack(nn.Module):
    in_channels: int
    num_hiddens: int
    num_residual_layers: int
    num_residual_hiddens: int

    def setup(self):
        self._layers = [
            Residual(self.in_channels, self.num_hiddens, self.num_residual_hiddens)
            for _ in range(self.num_residual_layers)
        ]

    def __call__(self, x):
        for layer in self._layers:
            x = layer(x)
        return nn.relu(x)

class Encoder(nn.Module):
    in_channels: int
    num_hiddens: int
    num_residual_layers: int
    num_residual_hiddens: int

    def setup(self):
        self._linear_1 = nn.Dense(self.num_hiddens // 2)
        self._linear_2 = nn.Dense(self.num_hiddens)
        self._residual_stack = ResidualStack(self.in_channels, self.num_hiddens,
                                             self.num_residual_layers, self.num_residual_hiddens)

    def __call__(self, inputs):
        x = self._linear_1(inputs.reshape((inputs.shape[0], -1)))
        x = nn.relu(x)
        x = self._linear_2(x)
        x = nn.relu(x)
        return self._residual_stack(x)

class Decoder(nn.Module):
    in_channels: int
    num_hiddens: int
    num_residual_layers: int
    num_residual_hiddens: int

    def setup(self):
        self._linear_1 = nn.Dense(self.num_hiddens)
        self._residual_stack = ResidualStack(self.in_channels, self.num_hiddens,
                                             self.num_residual_layers, self.num_residual_hiddens)
        self._linear_2 = nn.Dense(self.num_hiddens // 2)

    def __call__(self, inputs):
        x = self._linear_1(inputs.reshape((inputs.shape[0], -1)))
        x = self._residual_stack(x)
        x = self._linear_2(x)
        x = nn.relu(x)
        return x

In [11]:
import sys
print(sys.version)

3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]


In [13]:
!curl https://pyenv.run | bash
import os
os.environ['PATH'] += ":/root/.pyenv/bin"
!pyenv install 3.8.0
!pyenv global 3.8.0

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   270  100   270    0     0   1357      0 --:--:-- --:--:-- --:--:--  1356
Cloning into '/root/.pyenv'...
remote: Enumerating objects: 1162, done.[K
remote: Counting objects: 100% (1162/1162), done.[K
remote: Compressing objects: 100% (666/666), done.[K
remote: Total 1162 (delta 674), reused 633 (delta 363), pack-reused 0[K
Receiving objects: 100% (1162/1162), 578.51 KiB | 13.45 MiB/s, done.
Resolving deltas: 100% (674/674), done.
Cloning into '/root/.pyenv/plugins/pyenv-doctor'...
remote: Enumerating objects: 11, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 11 (delta 1), reused 5 (delta 0), pack-reused 0[K
Receiving objects: 100% (11/11), 38.72 KiB | 9.68 MiB/s, done.
Resolving deltas: 100% (1/1), done.
Cloning into '/root/.pyenv/plugins/pyenv-u

In [18]:
!pyenv global 3.8.0

In [19]:
import sys
print(sys.version)

3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]


In [14]:
"""Tests for haiku._src.nets.vqvae."""

import functools

from absl.testing import absltest
from absl.testing import parameterized

from haiku._src import stateful
from haiku._src import test_utils
from haiku._src import transform
from haiku._src.nets import vqvae
import jax
import jax.numpy as jnp
import numpy as np


class VqvaeTest(parameterized.TestCase):

  @parameterized.parameters((vqvae.VectorQuantizer, {
      'embedding_dim': 4,
      'num_embeddings': 8,
      'commitment_cost': 0.25
  }), (vqvae.VectorQuantizerEMA, {
      'embedding_dim': 6,
      'num_embeddings': 13,
      'commitment_cost': 0.5,
      'decay': 0.1
  }))
  @test_utils.transform_and_run
  def testConstruct(self, constructor, kwargs):
    vqvae_module = constructor(**kwargs)
    # Batch of input vectors to quantize
    inputs_np = np.random.randn(100, kwargs['embedding_dim']).astype(np.float32)
    inputs = jnp.array(inputs_np)

    # Set is_training to False, otherwise for the EMA case just evaluating the
    # forward pass will change the embeddings, meaning that some of our computed
    # closest embeddings will be incorrect.
    vq_output = vqvae_module(inputs, is_training=False)

    # Output shape is correct
    self.assertEqual(vq_output['quantize'].shape, inputs.shape)

    vq_output_np = jax.tree_util.tree_map(lambda t: t, vq_output)
    embeddings_np = vqvae_module.embeddings

    self.assertEqual(embeddings_np.shape,
                     (kwargs['embedding_dim'], kwargs['num_embeddings']))

    # Check that each input was assigned to the embedding it is closest to.
    distances = (jnp.square(inputs_np).sum(axis=1, keepdims=True) -
                 2 * np.dot(inputs_np, embeddings_np) +
                 jnp.square(embeddings_np).sum(axis=0, keepdims=True))
    closest_index = np.argmax(-distances, axis=1)
    # On TPU, distances can be different by ~1% due to precision. This can cause
    # the distanc to the closest embedding to flip, leading to a difference
    # in the encoding indices tensor. First we check that the continuous
    # distances are reasonably close, and then we only allow N differences in
    # the encodings. For batch of 100, N == 3 seems okay (passed 1000x tests).
    np.testing.assert_allclose(distances, vq_output_np['distances'], atol=5e-2)
    num_differences_in_encodings = (closest_index !=
                                    vq_output_np['encoding_indices']).sum()
    num_differences_allowed = 3
    self.assertLessEqual(num_differences_in_encodings, num_differences_allowed)

  @parameterized.parameters((vqvae.VectorQuantizer, {
      'embedding_dim': 4,
      'num_embeddings': 8,
      'commitment_cost': 0.25
  }), (vqvae.VectorQuantizerEMA, {
      'embedding_dim': 6,
      'num_embeddings': 13,
      'commitment_cost': 0.5,
      'decay': 0.1
  }))
  @test_utils.transform_and_run
  def testShapeChecking(self, constructor, kwargs):
    vqvae_module = constructor(**kwargs)
    wrong_shape_input = np.random.randn(100, kwargs['embedding_dim'] * 2)
    with self.assertRaisesRegex(TypeError, 'total size must be unchanged'):
      vqvae_module(
          jnp.array(wrong_shape_input.astype(np.float32)), is_training=False)

  @parameterized.parameters((vqvae.VectorQuantizer, {
      'embedding_dim': 4,
      'num_embeddings': 8,
      'commitment_cost': 0.25
  }), (vqvae.VectorQuantizerEMA, {
      'embedding_dim': 6,
      'num_embeddings': 13,
      'commitment_cost': 0.5,
      'decay': 0.1
  }))
  @test_utils.transform_and_run
  def testNoneBatch(self, constructor, kwargs):
    """Check that vqvae can be built on input with a None batch dimension."""
    vqvae_module = constructor(**kwargs)
    inputs = jnp.zeros([0, 5, 5, kwargs['embedding_dim']])
    vqvae_module(inputs, is_training=False)

  @parameterized.parameters({'use_jit': True, 'dtype': jnp.float32},
                            {'use_jit': True, 'dtype': jnp.float64},
                            {'use_jit': False, 'dtype': jnp.float32},
                            {'use_jit': False, 'dtype': jnp.float64})
  @test_utils.transform_and_run
  def testEmaUpdating(self, use_jit, dtype):
    if jax.local_devices()[0].platform == 'tpu' and dtype == jnp.float64:
      self.skipTest('F64 not supported by TPU')

    embedding_dim = 6
    np_dtype = np.float64 if dtype is jnp.float64 else np.float32
    decay = np.array(0.1, dtype=np_dtype)
    vqvae_module = vqvae.VectorQuantizerEMA(
        embedding_dim=embedding_dim,
        num_embeddings=7,
        commitment_cost=0.5,
        decay=decay,
        dtype=dtype)

    if use_jit:
      vqvae_f = stateful.jit(vqvae_module, static_argnums=1)
    else:
      vqvae_f = vqvae_module

    batch_size = 16

    prev_embeddings = vqvae_module.embeddings

    # Embeddings should change with every forwards pass if is_training == True.
    for _ in range(10):
      inputs = np.random.rand(batch_size, embedding_dim).astype(dtype)
      vqvae_f(inputs, True)
      current_embeddings = vqvae_module.embeddings
      self.assertFalse((prev_embeddings == current_embeddings).all())
      prev_embeddings = current_embeddings

    # Forward passes with is_training == False don't change anything
    for _ in range(10):
      inputs = np.random.rand(batch_size, embedding_dim).astype(dtype)
      vqvae_f(inputs, False)
      current_embeddings = vqvae_module.embeddings
      self.assertTrue((current_embeddings == prev_embeddings).all())

  def testEmaCrossReplica(self):
    embedding_dim = 6
    batch_size = 16
    inputs = np.random.rand(jax.local_device_count(), batch_size, embedding_dim)
    embeddings = {}
    perplexities = {}

    for axis_name in [None, 'i']:
      def my_function(x, axis_name):
        decay = np.array(0.9, dtype=np.float32)
        vqvae_module = vqvae.VectorQuantizerEMA(
            embedding_dim=embedding_dim,
            num_embeddings=7,
            commitment_cost=0.5,
            decay=decay,
            cross_replica_axis=axis_name,
            dtype=jnp.float32)

        outputs = vqvae_module(x, is_training=True)
        return vqvae_module.embeddings, outputs['perplexity']

      vqvae_f = transform.transform_with_state(
          functools.partial(my_function, axis_name=axis_name))

      rng = jax.random.PRNGKey(42)
      rng = jnp.broadcast_to(rng, (jax.local_device_count(), *rng.shape))

      params, state = jax.pmap(
          vqvae_f.init, axis_name='i')(rng, inputs)
      update_fn = jax.pmap(vqvae_f.apply, axis_name='i')

      for _ in range(10):
        outputs, state = update_fn(params, state, None, inputs)
      embeddings[axis_name], perplexities[axis_name] = outputs

    # In the single-device case, specifying a cross_replica_axis should have
    # no effect. Otherwise, it should!
    if jax.device_count() == 1:
      # Have to use assert_allclose here rather than checking exact matches to
      # make the test pass on GPU, presumably because of nondeterministic
      # reductions.
      np.testing.assert_allclose(
          embeddings[None], embeddings['i'], rtol=1e-6, atol=1e-6)
      np.testing.assert_allclose(
          perplexities[None], perplexities['i'], rtol=1e-6, atol=1e-6)
    else:
      self.assertFalse((embeddings[None] == embeddings['i']).all())
      self.assertFalse((perplexities[None] == perplexities['i']).all())


if __name__ == '__main__':
  absltest.main()

Running tests under Python 3.10.12: /usr/bin/python3
FATAL Flags parsing error: Unknown command line flag 'f'
Pass --helpshort or --helpfull to see help on flags.
E0829 03:38:20.981939 138459361918976 ultratb.py:152] Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 156, in parse_flags_with_usage
    return FLAGS(args)
  File "/usr/local/lib/python3.10/dist-packages/absl/flags/_flagvalues.py", line 652, in __call__
    raise _exceptions.UnrecognizedFlagError(
absl.flags._exceptions.UnrecognizedFlagError: Unknown command line flag 'f'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-14-323b34bb46c9>", line 195, in <cell line: 194>
    absltest.main()
  File "/usr/local/lib/python3.10/dist-packages/absl/testing/absltest.py", line 2060, in main
    _run_in_app(run_tests, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/absl/testing/absltest.py", line 2165, in _run_in_app
    app.run(main=main_function)


TypeError: ignored