VAE with variational nested dropout #60

Open · wants to merge 7 commits into master
5 changes: 5 additions & 0 deletions README.md
@@ -114,6 +114,7 @@ $ tensorboard --logdir .
| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
| VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937) | ![][31] | **N/A** |
| DIP VAE ([Code][dipvae_code], [Config][dipvae_config]) |[Link](https://arxiv.org/abs/1711.00848) | ![][36] | ![][35] |
| VNDAE ([Code][vndae_code], [Config][vndae_config]) |[Link](https://arxiv.org/pdf/2101.11353.pdf) | ![][37] | ![][38] |


<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
@@ -189,6 +190,7 @@ Additionally, if you would like to contribute some models, please submit a PR.
[infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py
[vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
[dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py
[vndae_code]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/models/vnd_ae.py

[vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml
[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
@@ -208,6 +210,7 @@ Additionally, if you would like to contribute some models, please submit a PR.
[infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml
[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml
[dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml
[vndae_config]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/configs/vndae.yaml

[1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
[2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png
@@ -244,6 +247,8 @@ Additionally, if you would like to contribute some models, please submit a PR.
[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png
[35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png
[36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DIPVAE_83.png
[37]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/assets/recons_VNDAE_1.png
[38]: https://github.com/ralphc1212/PyTorch-VAE/blob/deaac5a3165ea1048cfb129aa72b0a1c33f55041/assets/VNDAE_1.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
Binary file added assets/VNDAE_1.png
Binary file added assets/recons_VNDAE_1.png
28 changes: 28 additions & 0 deletions configs/vndae.yaml
@@ -0,0 +1,28 @@
model_params:
  name: 'VNDAE'
  in_channels: 3
  latent_dim: 128

data_params:
  data_path: "data/"
  train_batch_size: 64
  val_batch_size: 64
  patch_size: 64
  num_workers: 4

exp_params:
  LR: 0.005
  weight_decay: 0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [0]
  max_epochs: 50

logging_params:
  save_dir: "logs/"
  name: "VNDAE"
3 changes: 2 additions & 1 deletion models/__init__.py
@@ -21,7 +21,7 @@
from .vq_vae import *
from .betatc_vae import *
from .dip_vae import *

from .vnd_ae import *

# Aliases
VAE = VanillaVAE
@@ -35,6 +35,7 @@
'SWAE':SWAE,
'MIWAE':MIWAE,
'VQVAE':VQVAE,
'VNDAE':VNDAE,
'DFCVAE':DFCVAE,
'DIPVAE':DIPVAE,
'BetaVAE':BetaVAE,
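With this entry registered, the model can be looked up by the name field of a config. A sketch, assuming the dict shown here is the vae_models registry that the repo's run.py indexes:

from models import vae_models

model = vae_models['VNDAE'](in_channels=3, latent_dim=128)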
218 changes: 218 additions & 0 deletions models/vnd_ae.py
@@ -0,0 +1,218 @@
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *

TAU = 1.          # Gumbel-softmax temperature for sampling the nested mask
PI = 0.95         # prior keep-probability of each non-reserved latent unit
RSV_DIM = 1       # number of leading latent dimensions that are never dropped
EPS = 1e-8        # numerical stability constant for the log terms
SAMPLE_LEN = 1.   # fraction of latent dimensions kept when sampling

class VNDAE(BaseVAE):


def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
**kwargs) -> None:
super(VNDAE, self).__init__()

self.latent_dim = latent_dim

modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]

# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim

self.encoder = nn.Sequential(*modules)
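        # With the default hidden_dims, five stride-2 convs map a 64x64 input
        # to a 512-channel 2x2 feature map, so the flattened encoder output
        # has hidden_dims[-1] * 4 units.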
self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)
self.fc_p_vnd = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Prior over the nested-dropout cut position v: a truncated geometric
        # distribution with keep-probability PI, i.e. pv[k] = (1 - PI) * PI^k
        # for k < latent_dim - RSV_DIM, and pv[-1] = PI^(latent_dim - RSV_DIM)
        # (all non-reserved units kept).
        Pi = nn.Parameter(PI * torch.ones(latent_dim - RSV_DIM), requires_grad=False)

        self.ZERO = nn.Parameter(torch.tensor([0.]), requires_grad=False)
        self.ONE = nn.Parameter(torch.tensor([1.]), requires_grad=False)
        self.pv = nn.Parameter(torch.cat([self.ONE, torch.cumprod(Pi, dim=0)])
                               * torch.cat([1 - Pi, self.ONE]), requires_grad=False)

# Build Decoder
modules = []

self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

hidden_dims.reverse()

for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)

self.decoder = nn.Sequential(*modules)

self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
nn.Tanh())

@staticmethod
def clip_beta(tensor, to=5.):
"""
        Clamp all of the tensor's values to the range [-to, to].
"""
return torch.clamp(tensor, -to, to)

def encode(self, input: Tensor) -> List[Tensor]:
"""
        Encodes the input by passing it through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) List of latent codes [mu, log_var, p_vnd]
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)

# Split the result into mu, var and p_vnd components
# of the latent mixture
mu = self.fc_mu(result)
log_var = self.fc_var(result)
p_vnd = self.fc_p_vnd(result)

return [mu, log_var, p_vnd]

def decode(self, z: Tensor) -> Tensor:
"""
Maps the given latent codes
onto the image space.
:param z: (Tensor) [B x D]
:return: (Tensor) [B x C x H x W]
"""
result = self.decoder_input(z)
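        # Reshape the projection back to the encoder's final feature-map
        # shape (512 x 2 x 2 with the default hidden_dims).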
result = result.view(-1, 512, 2, 2)
result = self.decoder(result)
result = self.final_layer(result)
return result

def reparameterize(self, mu: Tensor, logvar: Tensor, p_vnd: Tensor) -> Tensor:
"""
Reparameterization trick to sample from the mixture posterior shown in Eq. 28 in [https://arxiv.org/pdf/2101.11353.pdf].
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
:param p_vnd: (Tensor) Parameter for the Downhill distribution [B x D]
:return: (Tensor) [B x D]
"""
std = torch.exp(0.5 * logvar)

        # Draw the Gaussian part and a sample from the Downhill distribution
        # (the nested-dropout mask over latent dimensions).
        eps = torch.randn_like(std)

        # Per-unit keep-probabilities for the non-reserved dimensions.
        beta = torch.sigmoid(self.clip_beta(p_vnd[:, RSV_DIM:]))
        ONES = torch.ones_like(beta[:, 0:1])
        # q(v = k) = (1 - beta_k) * prod_{j < k} beta_j: the probability that
        # the nested mask is cut at position k.
        qv = torch.cat([ONES, torch.cumprod(beta, dim=1)], dim=-1) * torch.cat([1 - beta, ONES], dim=-1)
        # Straight-through Gumbel-softmax over qv yields a one-hot cut position.
        s_vnd = F.gumbel_softmax(qv, tau=TAU, hard=True)

        # Turn the one-hot cut position into a binary keep-mask that is 1 up
        # to the cut and 0 afterwards; the RSV_DIM leading dimensions are
        # always kept.
        cumsum = torch.cumsum(s_vnd, dim=1)
        dif = cumsum - s_vnd
        mask0 = dif[:, 1:]
        mask1 = 1. - mask0
        s_vnd = torch.cat([torch.ones_like(p_vnd[:, :RSV_DIM]), mask1], dim=-1)

return (eps * std + mu) * s_vnd

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
mu, log_var, p_vnd = self.encode(input)
z = self.reparameterize(mu, log_var, p_vnd)
return [self.decode(z), input, mu, log_var, p_vnd]

def loss_function(self,
*args,
**kwargs) -> dict:
"""
        Computes the VNDAE loss function shown in Eq. 29 in [https://arxiv.org/pdf/2101.11353.pdf].
        :param args: [recons, input, mu, log_var, p_vnd], as returned by forward()
        :param kwargs: expects 'M_N', the minibatch weight on the KLD term
        :return: (dict) the total loss and its components
"""
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
p_vnd = args[4]
        # Recompute the Downhill distribution q(v) from the encoder output
        # (same construction as in reparameterize).
        beta = torch.sigmoid(self.clip_beta(p_vnd[:, RSV_DIM:]))
        ONES = torch.ones_like(beta[:, 0:1])
        qv = torch.cat([ONES, torch.cumprod(beta, dim=1)], dim=-1) * torch.cat([1 - beta, ONES], dim=-1)

        # coef1: per-unit weights built from the tail sums of q(v), i.e. how
        # much posterior mass keeps each latent unit alive; the reserved
        # dimensions are always weighted 1.
        ZEROS = torch.zeros_like(beta[:, 0:1])
        cum_sum = torch.cat([ZEROS, torch.cumsum(qv[:, 1:], dim=1)], dim=-1)[:, :-1]
        coef1 = torch.sum(qv, dim=1, keepdim=True) - cum_sum
        coef1 = torch.cat([torch.ones_like(p_vnd[:, :RSV_DIM]), coef1], dim=-1)

kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

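        # Per-dimension KL between N(mu, sigma^2) and the standard normal;
        # each term is weighted below by its unit's keep-probability (coef1).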
kld_gaussian = -0.5 * (1 + log_var - mu ** 2 - log_var.exp())

kld_weighted_gaussian = torch.diagonal(kld_gaussian.mm(coef1.t()), 0).mean()

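        # KL between the Downhill posterior q(v) and the geometric prior pv.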
log_frac = torch.log(qv / self.pv + EPS)
kld_vnd = torch.diagonal(qv.mm(log_frac.t()), 0).mean()

kld_loss = kld_vnd + kld_weighted_gaussian
loss = recons_loss + kld_weight * kld_loss
return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD': - kld_loss.detach()}

def sample(self,
num_samples:int,
current_device: int, **kwargs) -> Tensor:
"""
        Samples from the latent space, keeping only the leading SAMPLE_LEN
        fraction of latent dimensions active, and maps the codes to image space.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples,
self.latent_dim)

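        # Keep only the first SAMPLE_LEN fraction of latent dimensions and
        # zero out the rest (with SAMPLE_LEN = 1. the full code is used).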
        z = torch.cat([z[:, :int(SAMPLE_LEN * self.latent_dim)],
                       torch.zeros_like(z[:, :int((1 - SAMPLE_LEN) * self.latent_dim)])], dim=-1)
z = z.to(current_device)

samples = self.decode(z)
return samples

def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
        Given an input image x, returns the reconstructed image.
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""

return self.forward(x)[0]
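
An end-to-end usage sketch based on the signatures above: forward returns [reconstruction, input, mu, log_var, p_vnd], and loss_function expects the minibatch KLD weight as M_N (e.g. the kld_weight from configs/vndae.yaml):

import torch
from models import VNDAE

model = VNDAE(in_channels=3, latent_dim=128)
x = torch.randn(8, 3, 64, 64)        # a batch of 64x64 RGB patches
outputs = model(x)                   # [recons, input, mu, log_var, p_vnd]
losses = model.loss_function(*outputs, M_N=0.00025)
print(losses['loss'], losses['Reconstruction_Loss'], losses['KLD'])

# Decoding random codes; sample() zeroes dimensions beyond SAMPLE_LEN.
samples = model.sample(num_samples=4, current_device='cpu')  # [4, 3, 64, 64]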