In [67]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [115]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages,
    one per entry of ``hidden_dims``. Its flattened output feeds two linear
    heads producing the mean and log-variance of a diagonal Gaussian posterior.
    The decoder mirrors the encoder with stride-2 transposed convolutions and
    ends in a ``Tanh``, so reconstructions lie in ``[-1, 1]``.

    The linear heads are sized ``hidden_dims[-1] * 4`` and ``decode`` reshapes to
    ``(..., 2, 2)``, i.e. the encoder must reduce the input to a 2x2 feature map.
    With the default five stages that means 64x64 inputs.

    Args:
        in_channel: Number of channels of the input images.
        latent_dim: Size of the latent vector ``z``.
        hidden_dims: Output channels of each encoder stage, shallowest first.
            Defaults to ``[32, 64, 128, 256, 512]``. The list is not modified.

    Attributes:
        latent_dim: The latent size passed to the constructor.
        hidden_dims: Channel widths in *decoder* order (deepest first).
    """

    def __init__(self, in_channel: int,
                 latent_dim: int,
                 hidden_dims: list = None
                ) -> None:
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Copy instead of reversing in place: the caller's list must not be
        # mutated (re-using the same list for a second model would otherwise
        # silently build it with the channel order flipped).
        encoder_dims = list(hidden_dims)
        decoder_dims = encoder_dims[::-1]
        # Kept in decoder order because decode() reads hidden_dims[0] as the
        # channel count of the deepest feature map.
        self.hidden_dims = decoder_dims

        # Encoder
        modules = []
        for dim in encoder_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(dim),
                    nn.ReLU(),
                )
            )
            in_channel = dim
        self.encoder = nn.Sequential(*modules)
        # "* 4" == 2 * 2: the encoder output is expected to be a 2x2 map.
        self.mu = nn.Linear(in_channel * 4, latent_dim)
        self.log_var = nn.Linear(in_channel * 4, latent_dim)

        # Projects a latent vector back to the flattened 2x2 feature map.
        self.zt = nn.Linear(latent_dim, in_channel * 4)

        # Decoder
        modules = []
        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i], decoder_dims[i + 1], kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)
        # NOTE(review): the output head always emits 3 channels regardless of
        # the constructor's in_channel - confirm this is intended for non-RGB data.
        self.output = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1], decoder_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.ReLU(),
            nn.Conv2d(decoder_dims[-1], 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for batch ``x``."""
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.mu(x)
        log_var = self.log_var(x)
        return (mu, log_var)

    def decode(self, z):
        """Map latent vectors ``z`` of shape ``[N, latent_dim]`` to images."""
        z = self.zt(z)
        z = z.view(-1, self.hidden_dims[0], 2, 2)
        x = self.decoder(z)
        x = self.output(x)
        return x

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``.

        The noise must be standard normal (``randn``), not uniform (``rand``):
        the KL term in ``loss`` is the closed form for a Gaussian posterior
        against a standard-normal prior.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return std * eps + mu

    def _reconstruct(self, x):
        """Run one encode -> sample -> decode pass; return ``(re_x, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def forward(self, x):
        """Return a stochastic reconstruction of ``x``."""
        re_x, _, _ = self._reconstruct(x)
        return re_x

    def loss(self, x, reg):
        """Return the scalar training loss ``MSE(x, re_x) + reg * KL``.

        The KL term is summed over latent dimensions and averaged over the
        batch. The encoder is run exactly once so that BatchNorm running
        statistics are updated once per call.

        Args:
            x: Input batch; also used as the reconstruction target.
            reg: Weight of the KL term.
        """
        re_x, mu, log_var = self._reconstruct(x)
        mse = F.mse_loss(re_x, x)
        kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        return mse + reg * kl_loss

    def inference(self, num_samples, device):
        """Decode ``num_samples`` latent vectors drawn from the ``N(0, I)`` prior.

        ``device`` is where the latents are created; it should be the device
        the model's parameters live on.
        """
        z = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decode(z)

In [167]:
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Each spatial position of the input feature map is replaced by the closest
    codebook vector (squared Euclidean distance).

    Args:
        dict_size: Number of codebook entries (stored as ``K``).
        dimension: Dimensionality of each codebook vector (stored as ``D``);
            must equal the channel count of the tensors passed to ``forward``.
        beta: Weight of the commitment term of the quantization loss.
    """

    def __init__(self, dict_size, dimension, beta):
        super(VectorQuantizer, self).__init__()
        self.K = dict_size
        self.D = dimension
        self.beta = beta
        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, code):
        """Quantize ``code`` and return ``(q_code, loss)``.

        Args:
            code: Encoder output of shape ``[B, D, H, W]``.

        Returns:
            q_code: Tensor of shape ``[B, D, H, W]`` whose values are the
                selected codebook vectors; its gradient is passed straight
                through to ``code``.
            loss: Scalar ``codebook_loss + beta * commitment_loss``.

        Raises:
            ValueError: If ``code`` is not 4-D or its channel count is not ``D``.
        """
        if code.dim() != 4 or code.shape[1] != self.D:
            raise ValueError(
                f"expected code of shape [B, {self.D}, H, W], got {tuple(code.shape)}"
            )

        code = code.permute(0, 2, 3, 1).contiguous()  # [B, H, W, D]
        flat_code = code.view(-1, self.D)  # [BHW, D]
        codebook = self.embedding.weight  # [K, D]

        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, for every (position, entry) pair.
        l2 = (
            torch.sum(flat_code ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(flat_code, codebook.t())
        )  # [BHW, K]
        min_indices = torch.argmin(l2, dim=1)  # [BHW]

        # Direct lookup: same result and gradient as one-hot @ codebook, without
        # materialising a [BHW, K] matrix or assuming a float32 dtype.
        cb_encoding = self.embedding(min_indices).view_as(code)  # [B, H, W, D]

        # Codebook term moves the selected entries toward the (frozen) encoder
        # output; commitment term moves the encoder output toward the (frozen)
        # entries. beta scales the commitment term.
        codebook_loss = F.mse_loss(cb_encoding, code.detach())
        commitment_loss = F.mse_loss(code, cb_encoding.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: forward value is cb_encoding, backward
        # gradient goes to code (the encoder), not to the codebook.
        q_code = code + (cb_encoding - code).detach()
        q_code = q_code.permute(0, 3, 1, 2).contiguous()
        return q_code, loss
    
class ResidualLayer(nn.Module):
    """Residual block computing ``x + conv3x3(relu(conv3x3(x)))``.

    Both convolutions use ``kernel_size=3`` and ``padding=1``. The branch output
    is added directly to the input, so its shape has to be compatible with
    ``x`` (in practice ``in_channel == out_channel``).

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels produced by both convolutions of the branch.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel, self.out_channel = in_channel, out_channel

        conv_kwargs = dict(kernel_size=3, padding=1)
        branch = [
            nn.Conv2d(in_channel, out_channel, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, **conv_kwargs),
        ]
        # Attribute name (including its spelling) is part of the state_dict keys.
        self.residule = nn.Sequential(*branch)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        delta = self.residule(x)
        return x + delta
    
class VQ_VAE(nn.Module):
    """Vector-quantized autoencoder built from conv stages and residual layers.

    Encoder: one stride-2 ``Conv2d -> ReLU`` stage per entry of ``hidden_dims``,
    a 3x3 conv, six ``ResidualLayer`` blocks, a ReLU and a final 1x1 conv that
    emits ``embedding_dim`` channels. The result is quantized by a
    ``VectorQuantizer`` and decoded by stride-2 transposed convolutions ending
    in a ``Tanh`` (outputs in ``[-1, 1]``, always 3 channels).

    Args:
        in_channel: Number of channels of the input images.
        embedding_dim: Dimensionality of each codebook vector; also the channel
            count of the latent map fed to the quantizer and the decoder.
        dict_size: Number of codebook entries.
        beta: Loss weight forwarded to ``VectorQuantizer``.
        hidden_dims: Output channels of each down-sampling stage, shallowest
            first. Defaults to ``[128, 256]``. The list is not modified.

    Attributes:
        hidden_dims: Channel widths in *decoder* order (deepest first).
        embedding_dim: The codebook vector size passed to the constructor.
    """

    def __init__(self, in_channel: int,
                 embedding_dim: int,
                 dict_size: int,
                 beta: float,
                 hidden_dims: list = None,
                ) -> None:
        super(VQ_VAE, self).__init__()
        if hidden_dims is None:
            hidden_dims = [128, 256]
        # Copy instead of reversing in place so the caller's list is untouched.
        encoder_dims = list(hidden_dims)
        decoder_dims = encoder_dims[::-1]
        self.hidden_dims = decoder_dims
        self.embedding_dim = embedding_dim

        # Encoder
        modules = []
        for dim in encoder_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, dim, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(),
                )
            )
            in_channel = dim
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1),
                nn.ReLU()
            )
        )
        for _ in range(6):
            modules.append(ResidualLayer(in_channel, in_channel))
        modules.append(nn.ReLU())

        # The 1x1 conv projects to embedding_dim so the latent map matches the
        # codebook vectors. When embedding_dim == hidden_dims[-1] this is the
        # same layer shape as a plain channel-preserving 1x1 conv.
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channel, embedding_dim, kernel_size=1),
                nn.ReLU()
            )
        )

        self.encoder = nn.Sequential(*modules)
        self.vq = VectorQuantizer(dict_size, embedding_dim, beta)

        # Decoder: the first stage consumes the quantized map (embedding_dim
        # channels); the remaining widths follow the reversed hidden_dims.
        decoder_channels = [embedding_dim] + decoder_dims[1:]

        modules = []
        for i in range(len(decoder_channels) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_channels[i], decoder_channels[i + 1], kernel_size=4, stride=2, padding=1, output_padding=0),
                    nn.BatchNorm2d(decoder_channels[i + 1]),
                    nn.ReLU()
                )
            )
        self.decoder = nn.Sequential(*modules)
        self.output = nn.Sequential(
            nn.ConvTranspose2d(decoder_channels[-1], decoder_channels[-1], kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.BatchNorm2d(decoder_channels[-1]),
            nn.ReLU(),
            nn.Conv2d(decoder_channels[-1], 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Return the continuous (pre-quantization) latent map for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Decode a latent map of shape ``[N, embedding_dim, H, W]`` to images."""
        x = self.decoder(z)
        x = self.output(x)
        return x

    def forward(self, x):
        """Return ``(re_x, cb_loss)``: the reconstruction and the quantizer loss."""
        x = self.encode(x)
        q_code, cb_loss = self.vq(x)
        re_x = self.decode(q_code)
        return re_x, cb_loss

    def loss(self, x):
        """Return the scalar loss ``MSE(x, re_x) + quantizer loss``."""
        re_x, cb_loss = self.forward(x)
        mse = F.mse_loss(re_x, x)
        return mse + cb_loss

    def inference(self, num_samples, device, latent_size=16):
        """Decode ``num_samples`` latent maps built from random codebook entries.

        A code index is drawn uniformly at random for every spatial position
        and the corresponding codebook vectors are decoded. This model holds no
        learned prior over code indices, so uniform sampling is only a
        baseline.

        Args:
            num_samples: Number of images to generate.
            device: Device on which the indices are created; should be the
                device the model's parameters live on.
            latent_size: Spatial size of the latent map, an int (square) or an
                ``(H, W)`` pair. The default 16 matches 64x64 inputs with the
                default two stride-2 encoder stages; the decoder up-samples by
                a factor of ``2 ** len(hidden_dims)``.
        """
        if isinstance(latent_size, int):
            height = width = latent_size
        else:
            height, width = latent_size

        indices = torch.randint(0, self.vq.K, (num_samples, height, width), device=device)
        z_q = self.vq.embedding(indices)  # [N, H, W, D]
        z_q = z_q.permute(0, 3, 1, 2).contiguous()  # [N, D, H, W]
        return self.decode(z_q)