In [1]:
# default_exp vae

In [2]:
#hide
%load_ext autoreload
%autoreload 2

# Variational Autoencoder

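> The variational autoencoder (VAE) and its vector-quantized variants (VQ-VAE, with a learned or an EMA-updated codebook).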

In [3]:
# export
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import typing
from typing import Sequence, Union, Tuple

from generative_models.layers import scale, unscale

In [4]:
# export
class VAEOutput(typing.NamedTuple):
    """Pair returned by `VAE.forward`: the decoder output and the KL penalty."""
    # Decoder output for the latent that was fed to it.
    pred: Tensor
    # KL term already multiplied by `beta` and divided by the batch size.
    kl_loss: Tensor

In [5]:
# export
class VAE(nn.Module):
    """Variational autoencoder with a diagonal-Gaussian posterior and a beta-weighted KL term.

    The encoder output is split in two halves along the last dimension, read as
    the posterior mean `mu` and log-variance `logvar`.

    Args:
        encoder: module whose output holds `mu` and `logvar` concatenated on the last dim.
        decoder: module mapping a latent `z` to a prediction.
        beta: weight applied to the KL term.
        d_z: size of the latent vector, used by `sample` when no `z` is given.
            If left as `None` it is inferred from the first forward pass.
    """

    def __init__(self, encoder:nn.Module, decoder:nn.Module, beta:float=1.,
                 d_z:typing.Optional[int]=None):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.beta = beta
        self.d_z = d_z

    def forward(self, x):
        """Encode `x`, draw a latent, decode it and return `VAEOutput(pred, kl_loss)`."""
        mu, logvar = self.encoder(x).chunk(2, -1)
        if self.d_z is None:
            # Remember the latent size so `sample()` works without extra arguments.
            self.d_z = mu.size(-1)
        z, kl_loss = self.reparametrize(mu, logvar)
        out = self.decoder(z)
        return VAEOutput(out, kl_loss)

    def reparametrize(self, mu, logvar):
        """Return `(z, kl_loss)`.

        In training mode `z = mu + eps * std` with `eps ~ N(0, I)`; in eval mode
        `z = mu`. `kl_loss` is the KL divergence to a standard normal, summed over
        all elements, divided by the batch size and multiplied by `beta`.
        """
        bs = mu.size(0)
        if self.training:
            std = logvar.mul(0.5).exp_()
            z = torch.randn_like(mu)*std + mu
        else:
            z = mu

        kl_loss = self.beta * 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / bs
        return z, kl_loss

    @torch.no_grad()
    def sample(self, z=None, n=100):
        """Decode `z`, or `n` latents drawn from N(0, I) when `z` is None.

        Raises:
            ValueError: if `z` is None and the latent size is unknown (no `d_z`
                was given and no forward pass has been run yet).
        """
        if z is None:
            if self.d_z is None:
                raise ValueError("latent size unknown: pass `d_z` to VAE(...), run a forward "
                                 "pass first, or provide `z` explicitly")
            # Draw the latents on the decoder's device (default device if it has no parameters).
            param = next(self.decoder.parameters(), None)
            device = param.device if param is not None else None
            z = torch.randn(n, self.d_z, device=device)
        # `unscale` comes from generative_models.layers; presumably it maps model
        # outputs back to the data range -- confirm against that module.
        return unscale(self.decoder(z))

    @torch.no_grad()
    def reconstruct(self, x):
        """Reconstruct `x` in eval mode (so `z = mu`), then restore the previous train/eval mode."""
        was_training = self.training
        self.eval()
        try:
            return unscale(self(x)[0])
        finally:
            # Do not leave a model that was being trained stuck in eval mode.
            self.train(was_training)

## Vector Quantized VAE

In [6]:
# export
from torch.autograd import Function

In [7]:
#export
class VQPseudoGrad(Function):
    """Straight-through estimator: the output is `q`, the gradient goes to `z`."""

    @staticmethod
    def forward(ctx, z, q):
        # `z` does not contribute to the forward value; it only receives the gradient.
        return q

    @staticmethod
    def backward(ctx, grad_output):
        # Copy the incoming gradient to `z` unchanged; `q` gets no gradient.
        grad_z, grad_q = grad_output, None
        return grad_z, grad_q


In [13]:
# export
class VectorQuantizer(nn.Module):
    """Vector quantizer with a learned codebook of `k` vectors of size `d`.

    Accepts channel-first feature maps `(b, d, h, w)` as well as flat latents
    `(b, d)`, matching `VectorQuantizerEMA`.

    Args:
        k: number of codebook entries.
        d: size of each codebook vector.
        commitment_cost: weight of the commitment term in the returned loss.
    """

    def __init__(self, k:int, d:int, commitment_cost:float=0.25):
        super().__init__()
        self.k = k
        self.commitment_cost = commitment_cost
        self.embedding = nn.Parameter(torch.empty(k, d))
        nn.init.uniform_(self.embedding, -1/k, 1/k)

    def forward(self, z):
        """Quantize `z` to its nearest codebook vectors.

        Returns:
            A tuple `(zq, loss, code, perplexity)`:
            `zq` has the shape of `z`, holds the selected codebook vectors and
            passes its gradient straight through to `z`;
            `loss` is the codebook loss plus `commitment_cost` times the commitment loss;
            `code` holds the selected indices, shaped `(b, h, w)` or `(b,)`;
            `perplexity` is the exponentiated entropy of the code usage in this batch.

        Raises:
            ValueError: if `z` is neither 2-D nor 4-D.
        """
        if z.dim() == 2:
            is_map = False
            z_ = z
        elif z.dim() == 4:
            is_map = True
            # Channels last, so the feature dimension is the one compared with the codebook.
            z_ = z.permute(0,2,3,1)
        else:
            raise ValueError(f'expected a 2-D (b, d) or 4-D (b, d, h, w) input, got {z.dim()}-D')
        e = self.embedding
        # Squared Euclidean distance expanded as |z|^2 - 2 z.e + |e|^2.
        distances = ((z_*z_).sum(-1, keepdim=True)
                    -2*torch.einsum('...d, nd -> ...n', z_, e)
                    +(e*e).sum(-1, keepdim=True).t())
        code = distances.argmin(-1)
        zq = F.embedding(code, e)
        if is_map:
            zq = zq.permute(0,3,1,2).contiguous()

        # Commitment loss pulls the encoder output towards the (fixed) code;
        # codebook loss pulls the code towards the (fixed) encoder output.
        e_latent_loss = F.mse_loss(zq.detach(), z)
        q_latent_loss = F.mse_loss(zq, z.detach())
        loss = q_latent_loss + e_latent_loss * self.commitment_cost

        avg_probs = F.one_hot(code.flatten(), self.k).float().mean(0)
        perplexity = (-avg_probs*(avg_probs+1e-10).log()).sum().exp()
        return VQPseudoGrad.apply(z, zq), loss, code, perplexity

    def extra_repr(self):
        return (f'(embedding): k={self.embedding.size(0)}, d={self.embedding.size(1)}')

In [14]:
# export
class EMA(nn.Module):
    """Exponential moving average with bias correction.

    Args:
        size: shape of the tracked quantity.
        gamma: decay rate; each update moves `avg` by `(1 - gamma)` of the
            distance to the new value.
    """

    def __init__(self, size:Tuple[int], gamma:float):
        super().__init__()
        self.register_buffer("avg", torch.zeros(*size))
        self.gamma = gamma
        # Equals gamma ** (number of updates); used to undo the bias towards the zero init.
        # NOTE(review): `cor` is a plain attribute, so it is not part of `state_dict()`;
        # after loading a checkpoint it restarts at 1 while `avg` is restored.
        self.cor = 1

    @torch.no_grad()
    def update(self, val):
        """Fold `val` into the running average (in place, never tracked by autograd)."""
        # Without no_grad, a `val` that requires grad would turn the buffer into a
        # non-leaf tensor whose autograd history grows with every update.
        self.cor *= self.gamma
        self.avg += (val - self.avg) * (1-self.gamma)

    @property
    def value(self):
        """Bias-corrected average. Before the first update this is 0/0, i.e. NaN."""
        return self.avg / (1. - self.cor)

    def updated_value(self, val):
        """Update with `val` and return the new bias-corrected average."""
        self.update(val)
        return self.value

In [15]:
# hide
ema = EMA((1, ), 0.9)

# Sanity check: feed 1..10 and print the bias-corrected average after every step.
for i in range(1, 11):
    smoothed = ema.updated_value(i)
    print(smoothed)

tensor([1.])
tensor([1.5263])
tensor([2.0701])
tensor([2.6313])
tensor([3.2097])
tensor([3.8052])
tensor([4.4176])
tensor([5.0466])
tensor([5.6920])
tensor([6.3534])


In [21]:
# export
class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook follows exponential moving averages of the encoder outputs.

    The codebook is a buffer, not a parameter: in training mode every forward
    pass re-estimates it from EMA cluster sizes and EMA cluster sums, so the
    returned loss contains only the commitment term.

    Args:
        k: number of codebook entries.
        d: size of each codebook vector.
        commitment_cost: weight of the commitment loss.
        gamma: decay rate of both moving averages.
        epsilon: smoothing constant that keeps unused clusters from dividing by zero.
    """

    def __init__(self, k:int, d:int, commitment_cost:float=0.25, gamma=0.99, epsilon=1e-5):
        super().__init__()
        self.commitment_cost = commitment_cost
        self.gamma, self.epsilon = gamma, epsilon
        self.k = k
        # Registered as a buffer: saved with the module but not seen by the optimizer.
        self.register_buffer("embedding", (torch.empty(k, d)))
        nn.init.uniform_(self.embedding, -1/k, 1/k)
        self.ema_cluster_size = EMA((k, ), gamma=gamma)
        self.ema_cluster_sum = EMA(self.embedding.size(), gamma=gamma)

    def forward(self, z):
        """Quantize `z`, shaped `(b, d)` or `(b, d, h, w)`, to its nearest codebook vectors.

        Returns:
            A tuple `(zq, loss, code, perplexity)`:
            `zq` has the shape of `z` and passes its gradient straight through to `z`;
            `loss` is `commitment_cost` times the commitment loss;
            `code` holds the selected indices, shaped `(b,)` or `(b, h, w)`;
            `perplexity` is the exponentiated entropy of the code usage in this batch.

        Raises:
            ValueError: if `z` is neither 2-D nor 4-D.
        """
        if z.dim() == 2:
            is_map = False
            z_ = z
        elif z.dim() == 4:
            is_map = True
            b,c,h,w = z.size()
            # reshape, not view: the permuted tensor is not contiguous in general.
            z_ = z.permute(0,2,3,1).reshape(-1, c)
        else:
            raise ValueError(f'expected a 2-D (b, d) or 4-D (b, d, h, w) input, got {z.dim()}-D')
        e = self.embedding
        # Squared Euclidean distance expanded as |z|^2 - 2 z.e + |e|^2.
        distances = ((z_*z_).sum(-1, keepdim=True)
                    -2*torch.einsum('...d, nd -> ...n', z_, e)
                    +(e*e).sum(-1, keepdim=True).t())
        code = distances.argmin(-1)
        code_oh = F.one_hot(code, self.k)
        # Kept from the original: the quantized tensor is flagged as requiring grad.
        zq = F.embedding(code, e).clone().requires_grad_()
        if is_map:
            zq = zq.view(b,h,w,c).permute(0,3,1,2).contiguous()

        e_latent_loss = F.mse_loss(zq.detach(), z)
        loss = e_latent_loss * self.commitment_cost

        # EMA update for the codebook. It is pure bookkeeping, so it runs without
        # autograd; otherwise the buffers would accumulate graph history through `z`.
        if self.training:
            with torch.no_grad():
                cluster_size = code_oh.sum(0)
                upd_ema_cluster_size = self.ema_cluster_size.updated_value(cluster_size)
                n = cluster_size.sum()
                # Smooth the cluster sizes so that empty clusters stay strictly positive.
                upd_ema_cluster_size = ((upd_ema_cluster_size + self.epsilon) /
                                        (n + self.k * self.epsilon) * n)
                cluster_sum = torch.zeros_like(self.embedding).scatter_add_(
                    0, code.unsqueeze(-1).expand_as(z_), z_)
                upd_ema_cluster_sum = self.ema_cluster_sum.updated_value(cluster_sum)

                self.embedding = upd_ema_cluster_sum / upd_ema_cluster_size[..., None]

        avg_probs = code_oh.float().mean(0)
        perplexity = (-avg_probs*(avg_probs+1e-10).log()).sum().exp()
        if is_map:
            # Same layout as VectorQuantizer, so VQVAE.decode can consume the codes.
            code = code.view(b,h,w)
        return VQPseudoGrad.apply(z, zq), loss, code, perplexity

    def extra_repr(self):
        return (f'(embedding): k={self.embedding.size(0)}, d={self.embedding.size(1)}')

In [22]:
# hide
vectors = torch.randn(100, 8)
vq = VectorQuantizerEMA(100, 8)

# Quantize random batches of 16 drawn from a fixed pool of 100 vectors and
# print the quantizer loss on every 50th step.
for i in range(1000):
    idx = torch.randint(0, 100, (16,))
    x = vectors[idx]
    outputs = vq(x)
    loss = outputs[1]
    if i % 50 == 49:
        print(loss.item())


0.1126675084233284
0.07073181122541428
0.10676535964012146
0.08209685236215591
0.08711182326078415
0.08703210204839706
0.10084635019302368
0.10472283512353897
0.08259377628564835
0.06909796595573425
0.059976726770401
0.06459784507751465
0.08399104326963425
0.07727716863155365
0.10665100812911987
0.08693363517522812
0.05164112523198128
0.07124526798725128
0.07908015698194504
0.06688578426837921


In [23]:
# export
class VQVAE(nn.Module):
    """Vector-quantized VAE: encoder -> vector quantizer -> decoder.

    Args:
        encoder: module producing the continuous latents.
        decoder: module mapping quantized latents to a reconstruction.
        k: number of codebook entries.
        d: size of each codebook vector.
        commitment_cost: weight of the commitment loss in the quantizer.
        use_ema: use `VectorQuantizerEMA` instead of `VectorQuantizer`.
    """

    def __init__(self, encoder, decoder, k:int, d:int, commitment_cost:float=0.25, use_ema:bool=False):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.quantize = (VectorQuantizerEMA(k, d, commitment_cost) if use_ema else
                         VectorQuantizer(k, d, commitment_cost))

    def forward(self, x):
        """Return `(x_hat, vq_loss, code, perplexity)` for the input batch `x`."""
        ze = self.encoder(x)
        zq, vq_loss, code, ppl = self.quantize(ze)
        x_hat = self.decoder(zq)
        return x_hat, vq_loss, code, ppl

    @torch.no_grad()
    def encode(self, x):
        """Return the codebook indices selected for `x`."""
        ze = self.encoder(x)
        # Run the quantizer in eval mode: in training mode the EMA quantizer would
        # update its codebook as a side effect of a mere lookup.
        was_training = self.quantize.training
        self.quantize.eval()
        try:
            code = self.quantize(ze)[2]
        finally:
            self.quantize.train(was_training)
        return code

    @torch.no_grad()
    def decode(self, code):
        """Look up the codebook vectors for `code` and decode them."""
        zq = F.embedding(code, self.quantize.embedding)
        if zq.dim() == 4:
            # (b, h, w, d) -> (b, d, h, w), the channel-first layout the quantizers emit.
            zq = zq.permute(0,3,1,2).contiguous()
        return self.decoder(zq)

In [26]:
from generative_models.layers import MLP

# Smoke test: a VQ-VAE with an EMA codebook (k=10 codes of size d=8) on a random batch.
# NOTE(review): MLP's signature is not visible here -- presumably (in, out, hidden, layers).
encoder, decoder = MLP(10, 8, 16, 2), MLP(8, 10, 16, 2)
vqvae = VQVAE(encoder, decoder, k=10, d=8, use_ema=True)

x = torch.randn(4, 10)
x_hat, vq_loss, _, _ = vqvae(x)

In [28]:
# Reconstruction error plus the quantizer loss; backward() checks that the graph is differentiable.
recon_loss = F.mse_loss(x_hat, x)
loss = recon_loss + vq_loss
loss.backward()

In [19]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_layers.ipynb.
Converted 01_training.ipynb.
Converted 02_made.ipynb.
Converted 03_pixelcnn.ipynb.
Converted 04_vae.ipynb.
Converted 10_experiments.pixelcnn.ipynb.
Converted index.ipynb.
