# Import modules

In [1]:
import os, h5py, torch, torchvision
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from einops import rearrange

# Hyperparameters

In [2]:
# NUM_LATENT_K = 20                 # Number of codebook entries
# NUM_LATENT_D = 64                 # Dimension of each codebook entries
# BETA = 1.0                        # Weight for the commitment loss

# INPUT_SHAPE = x_train.shape[1:]
# SIZE = None                       # Spatial size of latent embedding
#                                   # will be set dynamically in `build_vqvae

# VQVAE_BATCH_SIZE = 128            # Batch size for training the VQVAE
# VQVAE_NUM_EPOCHS = 20             # Number of epochs
# VQVAE_LEARNING_RATE = 3e-4        # Learning rate
# VQVAE_LAYERS = [16, 32]           # Number of filters for each layer in the encoder

# PIXELCNN_BATCH_SIZE = 128         # Batch size for training the PixelCNN prior
# PIXELCNN_NUM_EPOCHS = 10          # Number of epochs
# PIXELCNN_LEARNING_RATE = 3e-4     # Learning rate
# PIXELCNN_NUM_BLOCKS = 12          # Number of Gated PixelCNN blocks in the architecture
# PIXELCNN_NUM_FEATURE_MAPS = 32    # Width of each PixelCNN block

# Load Data

In [3]:
class LoadData(Dataset):
    def __init__(self, data_path, indices, transform=None):
        self.data_path = data_path
        self.transform = transform
        assert os.path.exists(self.data_path), "images path {} does not exist".format(self.data_path)
        self.data = h5py.File(self.data_path, 'r')
        self.images = self.data['images'][indices]  # array shape [480000,64,64,3], uint8 in range(256)
        self.labels = self.data['labels'][indices]  # array shape [480000,6], float64

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, index):
        image = self.images[index]
        label = self.labels[index]
        
        # Normalize the image to range [-1, 1]
        image = 2*(image.astype(np.float32)/255.0) - 1
        # Transpose image to fit PyTorch's [C, H, W] format
        image = np.transpose(image, (2, 0, 1))
        
        image = torch.tensor(image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)
        return image, label


# VQVAE Class

In [12]:
class VQVAE(nn.Module):
    def __init__(self):
        super().__init__() #super(VQVAE, self) is a python2 style for running super and call __init__() method of the parent class. In python3 we can use this new typping style.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 24, 4, stride=2, padding=1), # inp_channels=1, out_channels=16, kernel_size (size of receptive field)=4 => width & height output = 14. We have also 16 Weight matrices that result in having an output of shape [14,14,16]
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.Conv2d(24, 8, 4, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )
        
        self.pre_quant_conv = nn.Conv2d(8, 4, kernel_size=1) # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        self.embedding = nn.Embedding(num_embeddings=5, embedding_dim=4) # Three codebook vectors of size two - For more info regarding Embedding documentation: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
        self.post_quant_conv = nn.Conv2d(4, 24, kernel_size=1)
        
        # Commitment Loss Beta
        self.beta = 0.2
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(24, 6, 4, stride=2, padding=1),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.ConvTranspose2d(6, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )
        
        
    def forward(self, x):
        # B, C, H, W
        encoded_output = self.encoder(x)
        quant_input = self.pre_quant_conv(encoded_output)
        
        ## Quantization
        B, C, H, W = quant_input.shape
        quant_input = quant_input.permute(0, 2, 3, 1)
        quant_input = quant_input.reshape((quant_input.size(0), -1, quant_input.size(-1)))
        
        # Compute pairwise distances
        dist = torch.cdist(quant_input, self.embedding.weight[None, :].repeat((quant_input.size(0), 1, 1)))
        
        # Find index of nearest embedding
        min_encoding_indices = torch.argmin(dist, dim=-1)
        
        # Select the embedding weights
        quant_out = torch.index_select(self.embedding.weight, 0, min_encoding_indices.view(-1))
        quant_input = quant_input.reshape((-1, quant_input.size(-1)))
        
        # Compute losses
        commitment_loss = torch.mean((quant_out.detach() - quant_input)**2) # commitment loss pushes z_x toward embedding space and updates encoded space. It ensures that it does not grow faster than embedding space
        codebook_loss = torch.mean((quant_out - quant_input.detach())**2) # codebook loss pushes e_i toward z_x and update embedding space
        quantize_losses = codebook_loss + self.beta*commitment_loss
        
        # Ensure straight through gradient
        quant_out = quant_input + (quant_out - quant_input).detach()
        
        # Reshaping back to original input shape
        quant_out = quant_out.reshape((B, H, W, C)).permute(0, 3, 1, 2)
        min_encoding_indices = min_encoding_indices.reshape((-1, quant_out.size(-2), quant_out.size(-1)))
        
        
        ## Decoder part
        decoder_input = self.post_quant_conv(quant_out)
        output = self.decoder(decoder_input)
        return output, quantize_losses

# Sample & Split Data

In [5]:
def split_sample_data(n_samples=480_000, seed=42, n_sub_samples=None):
    np.random.seed(seed)
    if n_sub_samples:
        # Generate a list of n_sub_samples random integers between 0 and n_samples-1
        indices = np.random.randint(0, n_samples, size=n_sub_samples)
        # Calculate the number of test indices (20% of the list) and extract the test indices
        n_test = int(len(indices)*0.2)
        test_indices = np.random.choice(indices, size=n_test, replace=False)
        # Calculate the train indices (remaining 80% of the list)
        train_indices = [i for i in indices if i not in test_indices]
    else:
        # Generate an array between 0 and n_samples-1
        indices = np.arange(n_samples)
        # Calculate the number of test indices (20% of the list) and extract the test indices
        n_test = int(len(indices)*0.2)
        test_indices = np.random.choice(indices, size=n_test, replace=False)
        # Calculate the train indices (remaining 80% of the list)
        train_indices = [i for i in indices if i not in test_indices]
    train_indices.sort()
    test_indices = test_indices.tolist()
    test_indices.sort()
    return train_indices, test_indices

# VQVAE Implementation

In [13]:
def train_vqvae(train_indices, test_indices, path='/home/mohsen/Desktop/Academia/RUB Research Projects/INI/data/3dshapes/3dshapes.h5'):
    train = LoadData(data_path=path, indices=train_indices)
    test = LoadData(data_path=path, indices=test_indices)
    data_loader = DataLoader(train, batch_size=64, shuffle=True, num_workers=4)
    
    model = VQVAE().to(device)
    
    num_epochs = 20
    optimizer = Adam(model.parameters(), lr=1E-3)
    criterion = torch.nn.MSELoss()
    
    for epoch_idx in range(num_epochs):
        for im, label in data_loader:
            im = im.float().to(device)
            optimizer.zero_grad()
            out, quantize_loss = model(im)
            recon_loss = criterion(out, im)
            loss = recon_loss + quantize_loss
            loss.backward()
            optimizer.step()

        print(f'Recon Loss: {recon_loss.item()}, Quantize Loss: {quantize_loss.item()}, Total Loss: {loss.item()}')
        print('Finished epoch {}'.format(epoch_idx+1))
    print('Done Training...')
    
    # Reconstruction part
    
    idxs = torch.randint(0, len(test), (100, ))
    ims = torch.cat([test[idx][0][None, :] for idx in idxs]).float()
    ims = ims.to(device)
    model.eval()
    
    
    generated_im, _ = model(ims)
    ims = (ims+1)/2
    generated_im = 1 - (generated_im+1)/2
    out = torch.hstack([ims, generated_im])
    output = rearrange(out, 'b c h w -> b () h (c w)')
    grid = torchvision.utils.make_grid(output.detach().cpu(), nrow=10)
    img = torchvision.transforms.ToPILImage()(grid)
    img.save('reconstruction.png')
    # im.show()
    
    print('Done Reconstruction ...')


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_indices, test_indices = split_sample_data(n_samples=480_000, seed=42, n_sub_samples=1000)
train_vqvae(train_indices, test_indices)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Recon Loss: 0.667660653591156, Quantize Loss: 0.25083455443382263, Total Loss: 0.9184951782226562
Finished epoch 1
Recon Loss: 0.5990063548088074, Quantize Loss: 0.19450423121452332, Total Loss: 0.7935105562210083
Finished epoch 2
Recon Loss: 0.5863282680511475, Quantize Loss: 0.17511802911758423, Total Loss: 0.7614462971687317
Finished epoch 3
Recon Loss: 0.5011268854141235, Quantize Loss: 0.17984043061733246, Total Loss: 0.6809673309326172
Finished epoch 4
Recon Loss: 0.5054752230644226, Quantize Loss: 0.18088242411613464, Total Loss: 0.6863576173782349
Finished epoch 5
Recon Loss: 0.4698594808578491, Quantize Loss: 0.1983441710472107, Total Loss: 0.6682036519050598
Finished epoch 6
Recon Loss: 0.44251102209091187, Quantize Loss: 0.19866977632045746, Total Loss: 0.6411808133125305
Finished epoch 7
Recon Loss: 0.43299946188926697, Quantize Loss: 0.21844150125980377, Total Loss: 0.6514409780502319
Finished epoch 8
Recon Loss: 0.40997061133384705, Quantize Loss: 0.23678648471832275, Tot