In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
#Residual unit (pre-activation layout)
class Residual(nn.Module):
    '''
    Residual unit: BatchNorm -> ReLU -> 3x3 conv -> BatchNorm -> ReLU -> 1x1 conv,
    with the unit's input added back onto the result.

    Both convolutions keep the channel count at `dim` and use stride 1 with
    size-preserving padding, so the output has the same shape as the input.

    Args:
        dim: number of input (and output) channels.
    '''
    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.BatchNorm2d(dim),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        '''Return `self.model(x) + x` (the skip connection).'''
        branch = self.model(x)
        return branch + x

In [3]:
class ResBlock(nn.Module):
    '''
    A chain of `num_of_blocks` Residual units, each with `dim` channels,
    applied one after another.

    Args:
        num_of_blocks: how many Residual units to stack.
        dim: channel count passed to every Residual unit.
    '''
    def __init__(self, num_of_blocks, dim):
        super().__init__()
        self._num_of_blocks = num_of_blocks
        units = []
        for _ in range(self._num_of_blocks):
            units.append(Residual(dim))
        self.model = nn.ModuleList(units)

    def forward(self, x):
        '''Feed `x` through every stored unit in order and return the result.'''
        out = x
        for unit in self.model:
            out = unit(out)
        return out

In [4]:
#Encoder
class Encoder(nn.Module):
    '''
    Strided-convolution downsampler followed by two ResBlock stacks.

    input  -> NCHW tensor with `input_dim` channels
    output -> a single NCHW tensor with `out_dim` channels

    Every downsampling convolution uses kernel 4, stride 2, padding 1.
    `factor = 2` applies one such convolution; `factor = 4` applies two,
    with BatchNorm + ReLU in between.

    Args:
        input_dim: number of input channels.
        out_dim: number of output channels.
        num_of_blocks: number of Residual units in each of the two ResBlocks.
        factor: downsampling factor, either 2 or 4.
    '''
    def __init__(self, input_dim, out_dim, num_of_blocks, factor = 4):
        super(Encoder, self).__init__()
        # The message now names the values that are actually accepted.
        assert factor in [2, 4], 'Factor has to be either 2 or 4'

        if factor == 4:
            self.model = nn.Sequential(
                nn.Conv2d(input_dim, out_dim, kernel_size = 4, stride = 2, padding = 1),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(inplace = True),
                nn.Conv2d(out_dim, out_dim, kernel_size = 4, stride = 2, padding = 1),
            )

        elif factor == 2:
            self.model = nn.Sequential(
                 nn.Conv2d(input_dim, out_dim, kernel_size = 4, stride = 2, padding = 1),
            )

        self.res1 = ResBlock(num_of_blocks, out_dim)
        self.res2 = ResBlock(num_of_blocks, out_dim)

    def forward(self, x):
        '''Downsample `x`, then refine it with `res1` followed by `res2`.'''
        x = self.model(x)
        x = self.res2(self.res1(x))
        return x

In [5]:
#Quantized
class Quantized(nn.Module):
    '''
    Vector-quantization layer: replaces every spatial feature vector of the
    input with its nearest codebook entry (squared Euclidean distance).

    Args:
        num_embeddings: number of codebook entries.
        embed_dim: dimensionality of each codebook entry; the input's channel
            count must equal this value.
        commitment_cost: weight of the commitment term in the returned loss.
    '''
    def __init__(self, num_embeddings, embed_dim, commitment_cost = 0.25):
        super(Quantized, self).__init__()
        self.num_embeddings = num_embeddings
        self.embed_dim = embed_dim

        self.embeddings = nn.Embedding(self.num_embeddings, self.embed_dim)
        # Codebook starts uniformly in [-1/num_embeddings, 1/num_embeddings].
        nn.init.uniform_(self.embeddings.weight, -1./self.num_embeddings, 1./self.num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, x):
        '''
        input -> NCHW, with C == embed_dim

        Returns:
            (loss, quantized): `loss` is the scalar codebook loss plus
            `commitment_cost` times the commitment loss; `quantized` is an
            NCHW tensor of the same shape as the input.

        Raises:
            ValueError: if the channel dimension differs from `embed_dim`.
        '''
        # Move channels last so each spatial position is one feature vector.
        x = x.permute(0, 2, 3, 1).contiguous()
        input_shape = x.shape
        if input_shape[-1] != self.embed_dim:
            # Without this check the reshape below would silently regroup
            # values from different positions/channels into bogus vectors.
            raise ValueError(
                'Quantized expected %d input channels (embed_dim), got %d'
                % (self.embed_dim, input_shape[-1]))
        x_flat = x.reshape(-1, self.embed_dim)

        codebook = self.embeddings.weight
        # ||x||^2 + ||e||^2 - 2 x.e  ->  (num_vectors, num_embeddings)
        distances = (torch.sum(x_flat ** 2, dim = 1, keepdim = True)
                     + torch.sum(codebook ** 2, dim = 1)
                     - 2 * torch.matmul(x_flat, codebook.t()))

        # Index of the nearest codebook entry for every feature vector.
        encoding_indices = torch.argmin(distances, dim = 1)

        # Direct lookup; same values and gradients as multiplying a one-hot
        # matrix with the codebook, without materialising that matrix.
        quantized = self.embeddings(encoding_indices).view(input_shape)

        # Commitment term: gradient reaches only the input (codebook detached).
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        # Codebook term: gradient reaches only the codebook (input detached).
        q_latent_loss = F.mse_loss(quantized, x.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is the quantized tensor,
        # backward gradient is passed to `x` unchanged.
        quantized = x + (quantized - x).detach()

        return loss, quantized.permute(0, 3, 1, 2).contiguous()

In [6]:
#Decoder
class Decoder(nn.Module):
    '''
    Two ResBlock stacks followed by transposed-convolution upsampling.

    Each upsampling stage is BatchNorm -> ReLU -> ConvTranspose2d with
    kernel 4, stride 2, padding 1. `factor = 2` uses one stage mapping
    `input_dim` to `out_dim` channels; `factor = 4` puts an extra
    `input_dim` -> `input_dim` stage in front of it.

    Args:
        input_dim: number of input channels.
        out_dim: number of output channels.
        num_of_blocks: number of Residual units in each of the two ResBlocks.
        factor: upsampling factor, either 2 or 4.
    '''
    def __init__(self, input_dim, out_dim, num_of_blocks, factor = 4):
        super().__init__()
        assert factor in [2, 4], 'Factor has to be either 2 or 4'

        self.res1 = ResBlock(num_of_blocks, input_dim)
        self.res2 = ResBlock(num_of_blocks, input_dim)

        def upsample_stage(c_in, c_out):
            # One stride-2 upsampling stage as a flat list of layers.
            return [
                nn.BatchNorm2d(c_in),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
            ]

        if factor == 4:
            self.model = nn.Sequential(
                *upsample_stage(input_dim, input_dim),
                *upsample_stage(input_dim, out_dim),
            )
        elif factor == 2:
            self.model = nn.Sequential(*upsample_stage(input_dim, out_dim))

    def forward(self, x):
        '''Refine `x` with `res1` then `res2`, then upsample it.'''
        refined = self.res2(self.res1(x))
        return self.model(refined)

In [7]:
#VQ-VAE2
class VQVAE2(nn.Module):
    '''
    Two-level (top / bottom) vector-quantized autoencoder.

    The bottom encoder downsamples the input by 4 to `out_dim//2` channels;
    the top encoder downsamples that by a further 2 to `out_dim` channels.
    Both levels are projected to `embed_dim` channels by a 1x1 convolution
    before being quantized with their own codebook.

    Args:
        input_dim: number of channels of the input (and of the reconstruction).
        out_dim: channel width of the top encoder; the bottom level uses
            `out_dim//2`.
        num_of_blocks: number of Residual units per ResBlock in every
            encoder/decoder.
        num_embeddings: codebook size of each quantizer.
        embed_dim: dimensionality of the codebook vectors.
    '''
    def __init__(self, input_dim, out_dim, num_of_blocks,num_embeddings, embed_dim):
        super(VQVAE2, self).__init__()
        self.enc_b = Encoder(input_dim, out_dim//2, num_of_blocks, factor = 4)
        self.enc_t = Encoder(out_dim//2, out_dim, num_of_blocks, factor = 2)

        # 1x1 projection of the top features to the codebook dimensionality
        self.conv1 = nn.Conv2d(out_dim, embed_dim, kernel_size = 1, stride = 1, padding = 0)

        self.qt_t = Quantized(num_embeddings, embed_dim)

        # dec_t consumes the quantized top code, which has `embed_dim` channels
        # (not `out_dim`); sizing it with `out_dim` only worked when the two
        # happened to be equal.
        self.dec_t = Decoder(embed_dim, out_dim//2, num_of_blocks, factor = 2)
        # Input: decoded top code (out_dim//2) concatenated with bottom features (out_dim//2)
        self.conv2 = nn.Conv2d(out_dim//2 + out_dim//2, embed_dim, kernel_size = 1, stride = 1, padding = 0)

        self.qt_b = Quantized(num_embeddings, embed_dim)

        self.dec_b = Decoder(out_dim//2, input_dim, num_of_blocks, factor = 4)

        # Input: quantized bottom code (embed_dim) concatenated with decoded top code (out_dim//2)
        self.conv3 = nn.Conv2d(embed_dim + out_dim//2, out_dim//2, kernel_size = 1, stride = 1, padding = 0)

    def forward(self, x):
        '''Encode then decode `x`; returns `(loss, reconstruction)`.'''
        quant_t, quant_b, loss = self.encode(x)
        out = self.decode(quant_t, quant_b)
        return loss, out

    def encode(self, x):
        '''
        Quantize `x` at both levels.

        Returns:
            (quant_t, quant_b, loss): the quantized top and bottom codes and
            the sum of the two quantizer losses.
        '''
        enc_b = self.enc_b(x)
        enc_t = self.enc_t(enc_b)

        quant_t = self.conv1(enc_t)
        loss_1, quant_t = self.qt_t(quant_t)

        # Condition the bottom level on the decoded top code.
        dec_t = self.dec_t(quant_t)
        bottom_in = torch.cat([dec_t, enc_b], 1)

        quant_b = self.conv2(bottom_in)
        loss_2, quant_b = self.qt_b(quant_b)

        return quant_t, quant_b, loss_1 + loss_2

    def decode(self, quant_t, quant_b):
        '''Reconstruct the input from the quantized top and bottom codes.'''
        # NOTE: dec_t is evaluated here as well as in encode(), so a full
        # forward pass runs it twice on the same quant_t.
        dec_t = self.dec_t(quant_t)
        quant = torch.cat([quant_b, dec_t], dim = 1)
        quant = self.conv3(quant)
        dec = self.dec_b(quant)
        return dec

In [8]:
# Smoke test: one forward pass on a random 1x3x32x32 batch.
# Fall back to the CPU when no CUDA device is available instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VQVAE2(3, 256, 20, 512, 256).to(device)
data = torch.randn(1, 3, 32, 32).to(device)
loss, out = model(data)

In [9]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

In [10]:
#GatedPixelCNN

In [11]:
# Inspect the reconstruction's shape; the recorded output matches the 1x3x32x32 input batch.
out.shape

torch.Size([1, 3, 32, 32])