From b3eae1108fa3f07578bbb83a11cadb3b973a723f Mon Sep 17 00:00:00 2001 From: Zhijie Zheng <42537498+0809zheng@users.noreply.github.com> Date: Mon, 8 Jul 2024 14:56:00 +0800 Subject: [PATCH] Add files via upload --- _posts/deep-learning/2022-04-01-vae.md | 452 +++++++++++++------------ 1 file changed, 227 insertions(+), 225 deletions(-) diff --git a/_posts/deep-learning/2022-04-01-vae.md b/_posts/deep-learning/2022-04-01-vae.md index 770b8170..015b7b43 100644 --- a/_posts/deep-learning/2022-04-01-vae.md +++ b/_posts/deep-learning/2022-04-01-vae.md @@ -1,225 +1,227 @@ ---- -layout: post -title: '变分自编码器(Variational Autoencoder)' -date: 2022-04-01 -author: 郑之杰 -cover: 'https://pic.downk.cc/item/5fb715b8b18d627113d7ae88.jpg' -tags: 深度学习 ---- - -> Variational Autoencoder. - -本文目录: -1. 变分自编码器之“自编码器”:概率编码器与概率解码器 -2. 变分自编码器之“变分”:优化目标与重参数化 -3. 变分自编码器的各种变体 - - -# 1. 变分自编码器之“自编码器”:概率编码器与概率解码器 -**变分自编码器(Variational Autoencoder,VAE)**是一种深度生成模型,旨在学习已有数据集$$\{x_1,x_2,...,x_n\}$$的概率分布$p(x)$,并从数据分布中采样生成新的数据$\hat{x}$~$p(x)$。由于已有数据(也称**观测数据, observed data**)的概率分布形式是未知的,**VAE**把输入数据编码到**隐空间(latent space)**中,构造**隐变量(latent variable)**的概率分布$p(z)$,从隐变量中采样并重构新的数据,整个过程与自编码器类似。**VAE**的概率模型如下: - -$$ p(x) = \sum_{z}^{} p(x,z) = \sum_{z}^{}p(x|z)p(z) $$ - -如果人为指定隐变量的概率分布$p(z)$形式,则可以从中采样并通过解码器$p(x \| z)$(通常用神经网络拟合)生成新的数据。然而注意到此时隐变量的概率分布$p(z)$与输入数据无关,即给定一个输入数据$x_n$,从$p(z)$随机采样并重构为$\hat{x}_n$,将无法保证$x_n$与$\hat{x}_n$的对应性!此时生成模型常用的优化指标$Distance(x_n,\hat{x}_n)$等也无法使用。 - -在**VAE**中并不是**直接指定**隐变量的概率分布$p(z)$形式,而是为每个输入数据$x_n$指定一个**后验分布**$q(z \| x_n)$(通常为标准正态分布),则从该后验分布中采样并重构的$\hat{x}_n$对应于$x_n$。**VAE**指定后验分布$q(z \| x_n)$为标准正态分布$$\mathcal{N}(0,I)$$,则隐变量分布$p(z)$实际上也会是标准正态分布$$\mathcal{N}(0,I)$$: - -$$ p(z) = \sum_{x}^{} q(z|x)p(x) = \sum_{x}^{} \mathcal{N}(0,I)p(x)= \mathcal{N}(0,I)\sum_{x}^{} p(x)= \mathcal{N}(0,I) $$ - -**VAE**使用编码器(通常用神经网络)拟合后验分布$q(z \| x_n)$的均值$\mu_n$和方差$\sigma_n^2$(其数量由每批次样本量决定),通过训练使其接近标准正态分布$$\mathcal{N}(0,I)$$。实际上后验分布不可能**完全精确**地被拟合为标准正态分布,因为这样会使得$q(z \| x_n)$完全独立于输入数据$x_n$,从而使得重构效果极差。**VAE**的训练过程中隐含地存在着**对抗**的过程,最终使得$q(z \| x_n)$保留一定的输入数据$x_n$信息,并且对输入数据也具有一定的重构效果。 - -**VAE**的整体结构如下图所示。从给定数据中学习后验分布$q(z \| x)$的均值$\mu_n$和方差$\sigma_n^2$的过程称为**推断(inference)**,实现该过程的结构被称作**概率编码器(probabilistic encoder)**。从后验分布的采样结果中重构数据的过程$p(x \| z)$称为**生成(generation)**,实现该过程的结构被称作**概率解码器(probabilistic decoder)**。 - -![](https://pic.imgdb.cn/item/626b9206239250f7c5a98639.jpg) - -### ⚪讨论:后验分布可以选取其他分布吗? - -理论上,后验分布$q(z \| x_n)$可以选取任意可行的概率分布形式。然而从后续讨论中会发现对后验分布的约束是通过**KL**散度实现的,**KL**散度对于概率为$0$的点会发散,选择概率密度全局非负的标准正态分布$$\mathcal{N}(0,I)$$不会出现这种问题,且具有可以计算梯度的简洁的解析解。此外,由于服从正态分布的独立随机变量的和仍然是正态分布,因此隐空间中任意两点间的线性插值也是有意义的,并且可以通过线性插值获得一系列生成结果的展示。 - -### ⚪讨论:VAE的Bayesian解释 - -自编码器**AE**将观测数据$x$编码为特征向量$z$,每一个特征向量对应特征空间中的一个离散点,所有特征向量的分布是无序、散乱的,并且无法保证不存在特征向量的空间点能够重构出真实样本。**VAE**是**AE**的**Bayesian**形式,将特征向量看作随机变量,使其能够覆盖特征空间中的一片区域。进一步通过强迫所有数据的特征向量服从多维正态分布,从而解耦特征维度,使得特征的空间分布有序、规整。 - - - -# 2. 变分自编码器之“变分”:优化目标与重参数化 -**VAE**是一种隐变量模型$p(x,z)$,其优化目标为最大化观测数据的对数似然$\log p(x)=\log \sum_{z}^{} p(x,z)$。该问题是不可解的,因此采用[变分推断](https://0809zheng.github.io/2020/03/25/variational-inference.html)求解。变分推断的核心思想是引入一个新的分布$q(z \| x)$作为后验分布$p(z \| x)$的近似,从而构造**对数似然**$\log p(x)$的置信下界**ELBO**(也称**变分下界**, **variational lower bound**),通过最大化**ELBO**来代替最大化$\log p(x)$。采用[Jensen不等式](https://0809zheng.github.io/2022/07/20/jenson.html)可以快速推导**ELBO**的表达式: - -$$ \begin{aligned} \log p(x) &= \log \sum_{z}^{} p(x,z) = \log \sum_{z}^{} \frac{p(x,z)}{q(z|x)}q(z|x) \\ &= \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] \geq \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] \end{aligned} $$ - -上式表明变分下界**ELBO**是原优化目标$\log p(x)$的一个下界,两者的差距可以通过对$\log p(x)$的另一种写法获得: - -$$ \begin{aligned} \log p(x) &= \sum_{z}^{} q(z|x)\log p(x)= \Bbb{E}_{z \text{~} q(z|x)}[\log p(x)]\\ &= \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{p(z|x)}] = \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{p(z|x)}\frac{q(z|x)}{q(z|x)}] \\ &= \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] + \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{q(z|x)}{p(z|x)}] \end{aligned} $$ - -因此**VAE**的变分下界与原目标之间存在的**gap**为$\Bbb{E}_{z \text{~} q(z\|x)}[\log \frac{q(z\|x)}{p(z\|x)}]=KL(q(z\|x)\|\|p(z\|x))$。让**gap**为$0$的条件是$q(z\|x)=p(z\|x)$,即找到一个与真实后验分布$p(z\|x)$相同的分布$q(z\|x)$。然而$q(z\|x)$通常假设为较为简单的分布形式(如正态分布),不能拟合足够复杂的分布。因此**VAE**通常只是一个近似模型,优化的是代理(**surrogate**)目标,生成的图像比较模糊。 - -在**VAE**中,最大化**ELBO**等价于最小化如下损失函数: - -$$ \begin{aligned} \mathcal{L} &= -\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(x,z)}{q(z|x)}] = -\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(x|z)p(z)}{q(z|x)}] \\ &= -\mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)] - \mathbb{E}_{z \text{~}q(z|x)} [\log \frac{p(z)}{q(z|x)}] \\ &= \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)] \end{aligned} $$ - -直观上损失函数可以分成两部分:其中$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x \| z)]$表示生成模型$p(x\|z)$的**重构损失**,$KL[q(z\|x)\|\|p(z)]$表示后验分布$q(z\|x)$的**正则化项**(**KL**损失)。这两个损失并不是独立的,因为重构损失很小表明$p(x\|z)$置信度较大,即解码器重构比较准确,则编码器$q(z\|x)$不会太随机(即应和$x$相关性较高),此时**KL**损失不会小;另一方面**KL**损失很小表明编码器$q(z\|x)$随机性较高(即和$x$无关),此时重构损失不可能小。因此**VAE**的损失隐含着对抗的过程,在优化过程中总损失减小才对应模型的收敛。下面分别讨论这两种损失的具体形式。 - -## (1) 后验分布$q(z\|x)$的正则化项 - -损失$KL[q(z\|x)\|\|p(z)]$衡量后验分布$q(z\|x)$和先验分布$p(z)$之间的**KL**散度。$q(z\|x)$优化的目标是趋近标准正态分布,此时$p(z)$指定为标准正态分布$z$~$$\mathcal{N}(0,I)$$。$q(z\|x)$通过神经网络进行拟合(即概率编码器),其形式人为指定为**多维对角正态分布** $$\mathcal{N}(\mu,\sigma^{2})$$。 - -由于两个分布都是正态分布,**KL**散度有闭式解(**closed-form solution**),计算如下: - -$$ \begin{aligned} KL[q(z|x)||p(z)] &= KL[\mathcal{N}(\mu,\sigma^{2})||\mathcal{N}(0,1)] \\ &= \int_{}^{} \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}}}{\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}} dx \\&= \int_{}^{} \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} [-\frac{1}{2}\log \sigma^2 + \frac{x^2}{2}-\frac{(x-\mu)^2}{2\sigma^2}] dx \\ &= \frac{1}{2} (-\log \sigma^2 + \mu^2+\sigma^2-1) \end{aligned} $$ - -在实际实现时拟合$\log \sigma^2$而不是$\sigma^2$,因为$\sigma^2$总是非负的,需要增加激活函数进行限制;而$\log \sigma^2$的取值是任意的。**KL**损失的**Pytorch**实现如下: - -```python -kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) -``` - - -## (2) 生成模型$p(x\|z)$的重构损失 - -重构损失$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x \| z)]$表示把观测数据映射到隐空间后再重构为原数据的过程。其中生成模型$p(x\|z)$也是通过神经网络进行拟合的(即概率解码器),根据所处理数据的类型不同,$p(x\|z)$应选择不同的分布形式。 - -### ⚪二值数据:伯努利分布 -如果观测数据$x$为二值数据(如二值图像),则生成模型$p(x\|z)$建模为伯努利分布: - -$$ p(x|z)= \begin{cases} \rho(z) , & x=1 \\ 1-\rho(z), & x=0 \end{cases} = (\rho(z))^{x}(1-\rho(z))^{1-x} $$ - -使用神经网络拟合参数$\rho$: - -$$ \rho = \mathop{\arg \max}_{\rho} \log p(x|z) = \mathop{\arg \min}_{\rho} -x\log \rho(z)-(1-x)\log(1-\rho(z)) $$ - -上式表示交叉熵损失函数,且$\rho(z)$需要经过**sigmoid**等函数压缩到$[0,1]$。 - -### ⚪一般数据:正态分布 - -对于一般的观测数据$x$,将生成模型$p(x\|z)$建模为具有固定方差$\sigma^2_0$的正态分布: - -$$ p(x|z) = \frac{1}{\sqrt{2\pi\sigma_0^2}}e^{-\frac{(x-\mu)^2}{2\sigma_0^2}} $$ - -使用神经网络拟合参数$\mu$: - -$$ \mu = \mathop{\arg \max}_{\mu} \log p(x|z) = \mathop{\arg \min}_{\mu} \frac{(x-\mu)^2}{2\sigma_0^2} $$ - -上式表示**均方误差(MSE)**,**Pytorch**实现如下: - -```python -recons_loss = F.mse_loss(recons, input, reduction = 'sum') -``` - -注意`reduction`参数可选`'sum'`和`'mean'`,应该使用`'sum'`,这使得损失函数计算与原式保持一致。笔者在实现时曾选用`'mean'`,导致即使训练损失有下降,也只能生成噪声图片,推测是因为取平均使重构损失误差占比过小,无法正常训练。 - - -## (3) 重参数化技巧 -**VAE**的损失函数如下: - -$$ \mathcal{L} = \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)] $$ - -其中期望$$\mathbb{E}_{z \text{~} q(z\|x)} [\cdot]$$表示从从$q(z\|x)$中采样$z$的过程。由于采样过程是不可导的,不能直接参与梯度传播,因此引入[重参数化(reparameterization)技巧](https://0809zheng.github.io/2022/04/24/repere.html)。 - -已经假设$z$~$q(z\|x)$服从$$\mathcal{N}(\mu,\sigma^{2})$$,则$\epsilon=\frac{z-\mu}{\sigma}$服从标准正态分布$$\mathcal{N}(0,I)$$。因此有如下关系: - -$$ z = \mu + \sigma \cdot \epsilon $$ - -则从$$\mathcal{N}(0,I)$$中采样$\epsilon$,再经过参数变换构造$z$,可使得采样操作不用参与梯度下降,从而实现模型端到端的训练。 - -![](https://pic.imgdb.cn/item/627a22c909475431292e6d8d.jpg) - -在实现时对于每个样本只进行一次采样,采样的充分性是通过足够多的批量样本与训练轮数来保证的。则损失函数也可写作: - -$$ \mathcal{L} = -\log p(x | z) + KL[q(z|x)||p(z)], \quad z \text{~} q(z|x)$$ - -重参数化技巧的**Pytorch**实现如下: - -```python -def reparameterize(mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return mu + eps * std -``` - -### ⚪讨论:VAE的另一种建模方式 -对于一批已有样本,代表一个真实但形式未知的概率分布$\tilde{p}(x)$,可以构建一个带参数$\phi$的后验分布$q_{\phi}(z\|x)$,从而组成联合分布$q(x,z)=\tilde{p}(x)q_{\phi}(z\|x)$。如果人为定义一个先验分布$p(z)$,并构建一个带参数$\theta$的生成分布$p_{\theta}(x\|z)$,则可以构造另一个联合分布$p(x,z)=p(z)p_{\theta}(x\|z)$。**VAE**的目的是联合分布$q(x,z),p(x,z)$足够接近,因此最小化两者之间的**KL**散度: - -$$ \begin{aligned} KL(q(x,z)||p(x,z)) &= \mathbb{E}_{q(x,z)}[\log \frac{q(x,z)}{p(x,z)}] = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{\tilde{p}(x)q_{\phi}(z|x)}{p(z)p_{\theta}(x|z)}]] \\& = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[-\log p_{\theta}(x|z)] + \mathbb{E}_{q_{\phi}(z|x)}[\log \frac{q_{\phi}(z|x)}{p(z)}] + \mathbb{E}_{q_{\phi}(z|x)}[\log \tilde{p}(x)]] \\ &= \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[-\log p_{\theta}(x|z)] + KL(q_{\phi}(z|x)||p(z)) + Const.] \end{aligned} $$ - -### ⚪讨论:KL散度与互信息 - -上述联合分布的**KL**散度也可写作: - -$$ \begin{aligned} KL(q(x,z)||p(x,z)) & = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{\tilde{p}(x)q_{\phi}(z|x)}{p(z)p_{\theta}(x|z)}]] \\& = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{q_{\phi}(z|x)}{p(z)}]] - \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{p_{\theta}(x|z)}{\tilde{p}(x)}]] \\& = \mathbb{E}_{\tilde{p}(x)} [KL(q_{\phi}(z|x)||p(z))] - \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{p_{\theta}(x,z)}{\tilde{p}(x)p(z)}]] \end{aligned} $$ - -其中第一项为隐变量$z$的后验分布与先验分布之间的**KL**散度,第二项为观测变量$x$与为隐变量$z$之间的点互信息。因此**VAE**的优化目标也可以解释为最小化隐变量**KL**散度的同时最大化隐变量与观测变量的互信息。 - -# 3. 变分自编码器的各种变体 - -**VAE**的损失函数共涉及三个不同的概率分布:由概率编码器表示的后验分布$q(z\|x)$、隐变量的先验分布$p(z)$以及由概率解码器表示的生成分布$p(x\|z)$。对**VAE**的各种改进可以落脚于对这些概率分布的改进: - -- 后验分布$q(z\|x)$:后验分布为模型引入了正则化;一种改进思路是通过调整后验分布的正则化项增强模型的解耦能力(如$\beta$-**VAE**, **Disentangled** $\beta$-**VAE**, **InfoVAE**, **DIP-VAE**, **FactorVAE**, $\beta$-**TCVAE**, **HFVAE**)。 -- 先验分布$p(z)$:先验分布描绘了隐变量分布的隐空间;一种改进思路是通过引入标签实现半监督学习(如**CVAE**, **CMMA**);一种改进思路是通过对隐变量离散化实现聚类或分层特征表示(如**Categorical VAE**, **Joint VAE**, **VQ-VAE**, **VQ-VAE-2**);一种改进思路是更换隐变量的概率分布形式(如**Hyperspherical VAE**, **TD-VAE**, **f-VAE**, **NVAE**)。 -- 生成分布$p(x\|z)$:生成分布代表模型的数据重构能力;一种改进思路是将均方误差损失替换为其他损失(如**EL-VAE**, **DFCVAE**, **LogCosh VAE**)。 -- 改进整体损失函数:也有方法通过调整整体损失改进模型,如紧凑变分下界(如**IWAE**, **MIWAE**)或引入**Wasserstein**距离(如**WAE**, **SWAE**)。 -- 改进模型结构:如[BN-VAE](https://0809zheng.github.io/2022/04/18/bnvae.html)通过引入**BatchNorm**缓解**KL**散度消失问题(指较强的解码器允许训练时**KL**散度项$KL[q(z\|x)\|\|p(z)]=0$);引入对抗训练(如[AAE](https://0809zheng.github.io/2022/02/20/aae.html), [VAE-GAN](https://0809zheng.github.io/2022/02/17/vaegan.html))。 - - - - -| 方法 | 损失函数 | -| :---: | :---: | -| VAE | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+KL[q(z\|x)\|\|p(z)]$ | -| [CVAE](https://0809zheng.github.io/2022/04/02/cvae.html)
引入条件 | $\mathbb{E}_{z \text{~} q(z\|x,y)} [-\log p(x\|z,y)]+KL[q(z\|x,y)\|\|p(z\|y)]$ | -| [CMMA](https://0809zheng.github.io/2022/04/03/cmma.html)
隐变量$z$由标签$y$决定 | $\mathbb{E}_{z \text{~} q(z\|x,y)} [-\log p(x\|z)]+KL[q(z\|x,y)\|\|p(z\|y)]$ | -| [β-VAE](https://0809zheng.github.io/2020/12/02/bvae.html)
特征解耦 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+\beta \cdot KL[q(z\|x)\|\|p(z)]$ | -| [Disentangled β-VAE](https://0809zheng.github.io/2020/12/03/bvae2.html)
特征解耦 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+\gamma \cdot \|KL[q(z\|x)\|\|p(z)]-C\|$ | -| [InfoVAE](https://0809zheng.github.io/2020/12/04/infovae.html)
特征解耦 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+(1-\alpha) \cdot KL[q(z\|x)\|\|p(z)]+(\alpha+\lambda-1)\cdot \mathcal{D}_Z(q(z),p(z))$ | -| [DIP-VAE](https://0809zheng.github.io/2022/04/17/dipvae.html)
分离推断先验 | $$\begin{aligned} &\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+KL[q(z\|x)\|\|p(z)]\\&+\lambda_{od} \sum_{i \ne j} [\text{Cov}_{q(z)}[z]]^2_{ij}+\lambda_{d} \sum_{i} ([\text{Cov}_{q(z)}[z]]_{ii}-1)^2 \end{aligned}$$ | -| [FactorVAE](https://0809zheng.github.io/2022/04/15/factorvae.html)
特征解耦 | $$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+KL[q(z\|x)\|\|p(z)] +\gamma KL(q(z)\|\|\prod_{j}q(z_j))$$ | -| [β-TCVAE](https://0809zheng.github.io/2022/04/05/btcvae.html)
分离全相关项 | $$\begin{aligned} & \mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+\alpha KL(q(z,x)\|\|q(z)p(x)) \\ & +\beta KL(q(z)\|\|\prod_{j}q(z_j)) +\gamma \sum_{j}KL(q(z_j)\|\|p(z_j)) \end{aligned}$$ | -| [HFVAE](https://0809zheng.github.io/2022/04/16/hfvae.html)
隐变量特征分组 | $$\begin{aligned}&\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)] + \sum_{i}KL[q(z_{i})\|\|p(z_{i})] + \alpha KL(q(z,x)\|\|q(z)p(x)) \\ &+\beta \Bbb{E}_{q(z)}[\log \frac{q(z)}{\prod_{g}q(z^g)}-\log \frac{p(z)}{\prod_{g}p(z^g)}] + \gamma \sum_{g} \Bbb{E}_{q(z^g)}[\log \frac{q(z^g)}{\prod_{j}q(z^g_j)}-\log \frac{p(z^g)}{\prod_{j}p(z^g_j)}]\end{aligned}$$ | -| [Categorical VAE](https://0809zheng.github.io/2022/04/10/catevae.html)
离散隐变量: **Gumbel Softmax** | $\mathbb{E}_{z \text{~} q(c\|x)} [-\log p(x\|c)]+ KL[q(c\|x)\|\|p(c)]$ | -| [Joint VAE](https://0809zheng.github.io/2022/04/11/jointvae.html)
连续+离散隐变量 | $\mathbb{E}_{z,c \text{~} q(z,c\|x)} [-\log p(x\|z,c)]+\gamma_z \cdot \|KL[q(z\|x)\|\|p(z)]-C_z\|+\gamma_c \cdot \|KL[q(c\|x)\|\|p(c)]-C_c\|$ | -| [VQ-VAE](https://0809zheng.github.io/2020/11/10/vqvae.html)
向量量化隐变量 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z_e + sg[z_q-z_e])] + \|\| sg[z_e] - z_q \|\|_2^2 + \beta \|\| z_e - sg[z_q] \|\|_2^2$ | -| [VQ-VAE-2](https://0809zheng.github.io/2020/11/11/vqvae2.html)
VQ-VAE分层 | 同上,$z=\{z^{bottom},z^{top}\}$ | -| [Hyperspherical VAE](https://0809zheng.github.io/2022/04/19/svae.html)
引入**vMF**分布 | $$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)],\quad p(z)\text{~}C_{d,\kappa} e^{\kappa <\mu(x),z>}$$ | -| [TD-VAE](https://0809zheng.github.io/2022/04/21/tdvae.html)
**Markov**链状态空间 | $$\begin{aligned} \Bbb{E}_{(z_{t_1},z_{t_2}) \text{~} q}[ & -\log p_D(x_{t_2}\|z_{t_2}) -\log p_B(z_{t_1}\|b_{t_1}) -\log p_T(z_{t_2}\|z_{t_1}) \\ & +\log p_B(z_{t_2}\|b_{t_2})+\log q_S(z_{t_1}\|z_{t_2},b_{t_1},b_{t_2})] \end{aligned}$$ | -| [f-VAE](https://0809zheng.github.io/2022/04/22/fvae.html)
引入流模型 | $$\mathbb{E}_{u \text{~} q(u)} [-\log p(x\|F_x(u))-\log \|\det[\frac{\partial F_x(u)}{\partial u}]\|]+KL[q(u)\|\|p(F_x(u))]$$ | -| [NVAE](https://0809zheng.github.io/2022/04/20/nvae.html)
引入自回归高斯模型 | $$\begin{aligned}&\mathbb{E}_{z \text{~} q(z\|x,y)} [-\log p(x\|z,y)]+KL[q(z\|x,y)\|\|p(z\|y)] \\ &p(z) = \prod_{l=1}^{L} p(z_l\|z_{\lt l}) ,q(z\|x) = \prod_{l=1}^{L} q(z_l\|z_{\lt l},x)\end{aligned}$$ | -| [EL-VAE](https://0809zheng.github.io/2022/04/12/msssim.html)
引入**MS-SSIM**损失 | $I_M(x,\hat{x})^{\alpha_M} \prod_{j=1}^{M} C_j(x,\hat{x})^{\beta_j} S_j(x,\hat{x})^{\gamma_j}+ KL[q(z\|x)\|\|p(z)]$ | -| [DFCVAE](https://0809zheng.github.io/2022/04/09/dfcvae.html)
引入特征感知损失 | $$\alpha \sum_{l=1}^{L}\frac{1}{2C^lH^lW^l}\sum_{c=1}^{C^l}\sum_{h=1}^{H^l}\sum_{w=1}^{W^l}(\Phi(x)_{c,h,w}^l-\Phi(\hat{x})_{c,h,w}^l)^2+\beta \cdot KL[q(z\|x)\|\|p(z)]$$ | -| [LogCosh VAE](https://0809zheng.github.io/2022/04/13/logcosh.html)
引入**log cosh**损失 | $$\frac{1}{a} \log(\frac{e^{a(x-\hat{x})}+e^{-a(x-\hat{x})}}{2})+\beta \cdot KL[q(z\|x)\|\|p(z)]$$ | -| [IWAE](https://0809zheng.github.io/2022/04/07/iwae.html)
紧凑变分下界 | $$\Bbb{E}_{z_1,z_2,\cdots z_K \text{~} q(z\|x)}[\log \frac{1}{K}\sum_{k=1}^{K}\frac{p(x,z_k)}{q(z_k\|x)}]$$ | -| [MIWAE](https://0809zheng.github.io/2022/04/08/miwae.html)
紧凑变分下界 | $$\frac{1}{M}\sum_{m=1}^{M} \Bbb{E}_{z_{m,1},z_{m,2},\cdots z_{m,K} \text{~} q(z_{m}\|x)}[\log \frac{1}{K}\sum_{k=1}^{K}\frac{p(x,z_{m,k})}{q(z_{m,k}\|x)}]$$ | -| [WAE](https://0809zheng.github.io/2022/04/04/wae.html)
引入**Wasserstein**距离 | $$\Bbb{E}_{x\text{~}p(z)}\Bbb{E}_{z \text{~} q(z\|x)} [c(x,p(x\|z))]+\lambda \cdot \mathcal{D}_Z(q(z),p(z))$$ | -| [SWAE](https://0809zheng.github.io/2022/04/14/swae.html)
引入**Sliced-Wasserstein**距离 | $$\Bbb{E}_{x\text{~}p(z)}\Bbb{E}_{z \text{~} q(z\|x)} [c(x,p(x\|z))]+\int_{\Bbb{S}^{d-1}} \mathcal{W}[\mathcal{R}_{q(z)}(\cdot ;\theta),\mathcal{R}_{p(z)}(\cdot ;\theta)] d\theta$$ | - -# ⚪ 参考文献 -- [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114):(arXiv1312)VAE的原始论文。 -- [From Autoencoder to Beta-VAE](https://lilianweng.github.io/posts/2018-08-12-vae/):Blog by Lilian Weng. -- [变分自编码器(一):原来是这么一回事](https://spaces.ac.cn/archives/5253):Blog by 苏剑林. -- [Recent Advances in Autoencoder-Based Representation Learning](https://arxiv.org/abs/1812.05069v1):(arXiv1812)一篇VAE综述。 -- [PyTorch-VAE: A Collection of Variational Autoencoders (VAE) in PyTorch.](https://github.com/AntixK/PyTorch-VAE):(github)VAE的PyTorch实现。 -- [Learning Structured Output Representation using Deep Conditional Generative Models](https://0809zheng.github.io/2022/04/02/cvae.html):(NeurIPS2015)CVAE: 使用深度条件生成模型学习结构化输出表示。 -- [Importance Weighted Autoencoders](https://0809zheng.github.io/2022/04/07/iwae.html):(arXiv1509)IWAE:重要性加权自编码器。 -- [Learning to Generate Images with Perceptual Similarity Metrics](https://0809zheng.github.io/2022/04/12/msssim.html):(arXiv1511)使用多尺度结构相似性度量MS-SSIM学习图像生成。 -- [Adversarial Autoencoders](https://0809zheng.github.io/2022/02/20/aae.html):(arXiv1511)AAE:对抗自编码器。 -- [Autoencoding beyond pixels using a learned similarity metric](https://0809zheng.github.io/2022/02/17/vaegan.html):(arXiv1512)VAE-GAN:结合VAE和GAN。 -- [Variational methods for Conditional Multimodal Learning: Generating Human Faces from Attributes](https://0809zheng.github.io/2022/04/03/cmma.html):(arXiv1603)CMMA: 条件多模态学习的变分方法。 -- [Deep Feature Consistent Variational Autoencoder](https://0809zheng.github.io/2022/04/09/dfcvae.html):(arXiv1610)DFCVAE:使用特征感知损失约束深度特征一致性。 -- [Categorical Reparameterization with Gumbel-Softmax](https://0809zheng.github.io/2022/04/10/catevae.html):(arXiv1611)使用Gumble-Softmax实现离散类别隐变量的重参数化。 -- [β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework](https://0809zheng.github.io/2020/12/02/bvae.html):(ICLR1704)β-VAE:学习变分自编码器隐空间的解耦表示。 -- [InfoVAE: Balancing Learning and Inference in Variational Autoencoders](https://0809zheng.github.io/2020/12/04/infovae.html):(arXiv1706)InfoVAE:平衡变分自编码器的学习和推断过程。 -- [Variational Inference of Disentangled Latent Concepts from Unlabeled Observations](https://0809zheng.github.io/2022/04/17/dipvae.html):(arXiv1711)DIP-VAE: 分离推断先验VAE。 -- [Wasserstein Auto-Encoders](https://0809zheng.github.io/2022/04/04/wae.html):(arXiv1711)WAE: 使用**Wasserstein**距离的变分自编码器。 -- [Neural Discrete Representation Learning](https://0809zheng.github.io/2020/11/10/vqvae.html):(arXiv1711)VQ-VAE:向量量化的变分自编码器。 -- [Tighter Variational Bounds are Not Necessarily Better](https://0809zheng.github.io/2022/04/08/miwae.html):(arXiv1802)MIWAE:紧凑的变分下界阻碍推理网络训练。 -- [Disentangling by Factorising](https://0809zheng.github.io/2022/04/15/factorvae.html):(arXiv1802)FactorVAE:通过分解特征表示的分布进行解耦。 -- [Isolating Sources of Disentanglement in Variational Autoencoders](https://0809zheng.github.io/2022/04/05/btcvae.html):(arXiv1802)β-TCVAE: 分离VAE解耦源中的全相关项。 -- [Understanding disentangling in β-VAE](https://0809zheng.github.io/2020/12/03/bvae2.html):(arXiv1804)使用信息瓶颈解释β-VAE的解耦表示能力。 -- [Structured Disentangled Representations](https://0809zheng.github.io/2022/04/16/hfvae.html):(arXiv1804)HFVAE:通过层级分解VAE实现结构化解耦表示。 -- [Learning Disentangled Joint Continuous and Discrete Representations](https://0809zheng.github.io/2022/04/11/jointvae.html):(arXiv1804)Joint VAE:学习解耦的联合连续和离散表示。 -- [Sliced-Wasserstein Autoencoder: An Embarrassingly Simple Generative Model](https://0809zheng.github.io/2022/04/14/swae.html):(arXiv1804)SWAE:引入Sliced-Wasserstein距离构造VAE。 -- [Hyperspherical Variational Auto-Encoders](https://0809zheng.github.io/2022/04/19/svae.html):(arXiv1804)Hyperspherical VAE: 为隐变量引入vMF分布。 -- [Temporal Difference Variational Auto-Encoder](https://0809zheng.github.io/2022/04/21/tdvae.html):(arXiv1806)TD-VAE: 时间差分变分自编码器。 -- [f-VAEs: Improve VAEs with Conditional Flows](https://0809zheng.github.io/2022/04/22/fvae.html):(arXiv1809)f-VAE: 基于流的变分自编码器。 -- [Log Hyperbolic Cosine Loss Improves Variational Auto-Encoder](https://0809zheng.github.io/2022/04/13/logcosh.html):(OpenReview2018)使用对数双曲余弦损失改进变分自编码器。 -- [Generating Diverse High-Fidelity Images with VQ-VAE-2](https://0809zheng.github.io/2020/11/11/vqvae2.html):(arXiv1906)VQ-VAE-2:改进VQ-VAE生成高保真度图像。 -- [A Batch Normalized Inference Network Keeps the KL Vanishing Away](https://0809zheng.github.io/2022/04/18/bnvae.html):(arXiv2004)BN-VAE: 通过批量归一化缓解KL散度消失问题。 -- [NVAE: A Deep Hierarchical Variational Autoencoder](https://0809zheng.github.io/2022/04/20/nvae.html):(arXiv2007)Nouveau VAE: 深度层次变分自编码器。 +--- +layout: post +title: '变分自编码器(Variational Autoencoder)' +date: 2022-04-01 +author: 郑之杰 +cover: 'https://pic.downk.cc/item/5fb715b8b18d627113d7ae88.jpg' +tags: 深度学习 +--- + +> Variational Autoencoder. + +本文目录: +1. 变分自编码器之“自编码器”:概率编码器与概率解码器 +2. 变分自编码器之“变分”:优化目标与重参数化 +3. 变分自编码器的各种变体 + + +# 1. 变分自编码器之“自编码器”:概率编码器与概率解码器 +**变分自编码器(Variational Autoencoder,VAE)**是一种深度生成模型,旨在学习已有数据集$$\{x_1,x_2,...,x_n\}$$的概率分布$p(x)$,并从数据分布中采样生成新的数据$\hat{x}$~$p(x)$。由于已有数据(也称**观测数据, observed data**)的概率分布形式是未知的,**VAE**把输入数据编码到**隐空间(latent space)**中,构造**隐变量(latent variable)**的概率分布$p(z)$,从隐变量中采样并重构新的数据,整个过程与自编码器类似。**VAE**的概率模型如下: + +$$ p(x) = \sum_{z}^{} p(x,z) = \sum_{z}^{}p(x|z)p(z) $$ + +如果人为指定隐变量的概率分布$p(z)$形式,则可以从中采样并通过解码器$p(x \| z)$(通常用神经网络拟合)生成新的数据。然而注意到此时隐变量的概率分布$p(z)$与输入数据无关,即给定一个输入数据$x_n$,从$p(z)$随机采样并重构为$\hat{x}_n$,将无法保证$x_n$与$\hat{x}_n$的对应性!此时生成模型常用的优化指标$Distance(x_n,\hat{x}_n)$等也无法使用。 + +在**VAE**中并不是**直接指定**隐变量的概率分布$p(z)$形式,而是为每个输入数据$x_n$指定一个**后验分布**$q(z \| x_n)$(通常为标准正态分布),则从该后验分布中采样并重构的$\hat{x}_n$对应于$x_n$。**VAE**指定后验分布$q(z \| x_n)$为标准正态分布$$\mathcal{N}(0,I)$$,则隐变量分布$p(z)$实际上也会是标准正态分布$$\mathcal{N}(0,I)$$: + +$$ p(z) = \sum_{x}^{} q(z|x)p(x) = \sum_{x}^{} \mathcal{N}(0,I)p(x)= \mathcal{N}(0,I)\sum_{x}^{} p(x)= \mathcal{N}(0,I) $$ + +**VAE**使用编码器(通常用神经网络)拟合后验分布$q(z \| x_n)$的均值$\mu_n$和方差$\sigma_n^2$(其数量由每批次样本量决定),通过训练使其接近标准正态分布$$\mathcal{N}(0,I)$$。实际上后验分布不可能**完全精确**地被拟合为标准正态分布,因为这样会使得$q(z \| x_n)$完全独立于输入数据$x_n$,从而使得重构效果极差。**VAE**的训练过程中隐含地存在着**对抗**的过程,最终使得$q(z \| x_n)$保留一定的输入数据$x_n$信息,并且对输入数据也具有一定的重构效果。 + +**VAE**的整体结构如下图所示。从给定数据中学习后验分布$q(z \| x)$的均值$\mu_n$和方差$\sigma_n^2$的过程称为**推断(inference)**,实现该过程的结构被称作**概率编码器(probabilistic encoder)**。从后验分布的采样结果中重构数据的过程$p(x \| z)$称为**生成(generation)**,实现该过程的结构被称作**概率解码器(probabilistic decoder)**。 + +![](https://pic.imgdb.cn/item/626b9206239250f7c5a98639.jpg) + +### ⚪讨论:后验分布可以选取其他分布吗? + +理论上,后验分布$q(z \| x_n)$可以选取任意可行的概率分布形式。然而从后续讨论中会发现对后验分布的约束是通过**KL**散度实现的,**KL**散度对于概率为$0$的点会发散,选择概率密度全局非负的标准正态分布$$\mathcal{N}(0,I)$$不会出现这种问题,且具有可以计算梯度的简洁的解析解。此外,由于服从正态分布的独立随机变量的和仍然是正态分布,因此隐空间中任意两点间的线性插值也是有意义的,并且可以通过线性插值获得一系列生成结果的展示。 + +### ⚪讨论:VAE的Bayesian解释 + +自编码器**AE**将观测数据$x$编码为特征向量$z$,每一个特征向量对应特征空间中的一个离散点,所有特征向量的分布是无序、散乱的,并且无法保证不存在特征向量的空间点能够重构出真实样本。**VAE**是**AE**的**Bayesian**形式,将特征向量看作随机变量,使其能够覆盖特征空间中的一片区域。进一步通过强迫所有数据的特征向量服从多维正态分布,从而解耦特征维度,使得特征的空间分布有序、规整。 + + + +# 2. 变分自编码器之“变分”:优化目标与重参数化 +**VAE**是一种隐变量模型$p(x,z)$,其优化目标为最大化观测数据的对数似然$\log p(x)=\log \sum_{z}^{} p(x,z)$。该问题是不可解的,因此采用[变分推断](https://0809zheng.github.io/2020/03/25/variational-inference.html)求解。变分推断的核心思想是引入一个新的分布$q(z \| x)$作为后验分布$p(z \| x)$的近似,从而构造**对数似然**$\log p(x)$的置信下界**ELBO**(也称**变分下界**, **variational lower bound**),通过最大化**ELBO**来代替最大化$\log p(x)$。采用[Jensen不等式](https://0809zheng.github.io/2022/07/20/jenson.html)可以快速推导**ELBO**的表达式: + +$$ \begin{aligned} \log p(x) &= \log \sum_{z}^{} p(x,z) = \log \sum_{z}^{} \frac{p(x,z)}{q(z|x)}q(z|x) \\ &= \log \Bbb{E}_{z \text{~} q(z|x)}[\frac{p(x,z)}{q(z|x)}] \geq \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] \end{aligned} $$ + +上式表明变分下界**ELBO**是原优化目标$\log p(x)$的一个下界,两者的差距可以通过对$\log p(x)$的另一种写法获得: + +$$ \begin{aligned} \log p(x) &= \sum_{z}^{} q(z|x)\log p(x)= \Bbb{E}_{z \text{~} q(z|x)}[\log p(x)]\\ &= \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{p(z|x)}] = \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{p(z|x)}\frac{q(z|x)}{q(z|x)}] \\ &= \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{p(x,z)}{q(z|x)}] + \Bbb{E}_{z \text{~} q(z|x)}[\log \frac{q(z|x)}{p(z|x)}] \end{aligned} $$ + +因此**VAE**的变分下界与原目标之间存在的**gap**为$\Bbb{E}_{z \text{~} q(z\|x)}[\log \frac{q(z\|x)}{p(z\|x)}]=KL(q(z\|x)\|\|p(z\|x))$。让**gap**为$0$的条件是$q(z\|x)=p(z\|x)$,即找到一个与真实后验分布$p(z\|x)$相同的分布$q(z\|x)$。然而$q(z\|x)$通常假设为较为简单的分布形式(如正态分布),不能拟合足够复杂的分布。因此**VAE**通常只是一个近似模型,优化的是代理(**surrogate**)目标,生成的图像比较模糊。 + +在**VAE**中,最大化**ELBO**等价于最小化如下损失函数: + +$$ \begin{aligned} \mathcal{L} &= -\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(x,z)}{q(z|x)}] = -\mathbb{E}_{z \text{~} q(z|x)} [\log \frac{p(x|z)p(z)}{q(z|x)}] \\ &= -\mathbb{E}_{z \text{~} q(z|x)} [\log p(x | z)] - \mathbb{E}_{z \text{~}q(z|x)} [\log \frac{p(z)}{q(z|x)}] \\ &= \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)] \end{aligned} $$ + +直观上损失函数可以分成两部分:其中$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x \| z)]$表示生成模型$p(x\|z)$的**重构损失**,$KL[q(z\|x)\|\|p(z)]$表示后验分布$q(z\|x)$的**正则化项**(**KL**损失)。这两个损失并不是独立的,因为重构损失很小表明$p(x\|z)$置信度较大,即解码器重构比较准确,则编码器$q(z\|x)$不会太随机(即应和$x$相关性较高),此时**KL**损失不会小;另一方面**KL**损失很小表明编码器$q(z\|x)$随机性较高(即和$x$无关),此时重构损失不可能小。因此**VAE**的损失隐含着对抗的过程,在优化过程中总损失减小才对应模型的收敛。下面分别讨论这两种损失的具体形式。 + +## (1) 后验分布$q(z\|x)$的正则化项 + +损失$KL[q(z\|x)\|\|p(z)]$衡量后验分布$q(z\|x)$和先验分布$p(z)$之间的**KL**散度。$q(z\|x)$优化的目标是趋近标准正态分布,此时$p(z)$指定为标准正态分布$z$~$$\mathcal{N}(0,I)$$。$q(z\|x)$通过神经网络进行拟合(即概率编码器),其形式人为指定为**多维对角正态分布** $$\mathcal{N}(\mu,\sigma^{2})$$。 + +由于两个分布都是正态分布,**KL**散度有闭式解(**closed-form solution**),计算如下: + +$$ \begin{aligned} KL[q(z|x)||p(z)] &= KL[\mathcal{N}(\mu,\sigma^{2})||\mathcal{N}(0,1)] \\ &= \int_{}^{} \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \log \frac{\frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}}}{\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}} dx \\&= \int_{}^{} \frac{1}{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} [-\frac{1}{2}\log \sigma^2 + \frac{x^2}{2}-\frac{(x-\mu)^2}{2\sigma^2}] dx \\ &= \frac{1}{2} (-\log \sigma^2 + \mu^2+\sigma^2-1) \end{aligned} $$ + +在实际实现时拟合$\log \sigma^2$而不是$\sigma^2$,因为$\sigma^2$总是非负的,需要增加激活函数进行限制;而$\log \sigma^2$的取值是任意的。**KL**损失的**Pytorch**实现如下: + +```python +kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) +``` + + +## (2) 生成模型$p(x\|z)$的重构损失 + +重构损失$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x \| z)]$表示把观测数据映射到隐空间后再重构为原数据的过程。其中生成模型$p(x\|z)$也是通过神经网络进行拟合的(即概率解码器),根据所处理数据的类型不同,$p(x\|z)$应选择不同的分布形式。 + +### ⚪二值数据:伯努利分布 +如果观测数据$x$为二值数据(如二值图像),则生成模型$p(x\|z)$建模为伯努利分布: + +$$ p(x|z)= \begin{cases} \rho(z) , & x=1 \\ 1-\rho(z), & x=0 \end{cases} = (\rho(z))^{x}(1-\rho(z))^{1-x} $$ + +使用神经网络拟合参数$\rho$: + +$$ \rho = \mathop{\arg \max}_{\rho} \log p(x|z) = \mathop{\arg \min}_{\rho} -x\log \rho(z)-(1-x)\log(1-\rho(z)) $$ + +上式表示交叉熵损失函数,且$\rho(z)$需要经过**sigmoid**等函数压缩到$[0,1]$。 + +### ⚪一般数据:正态分布 + +对于一般的观测数据$x$,将生成模型$p(x\|z)$建模为具有固定方差$\sigma^2_0$的正态分布: + +$$ p(x|z) = \frac{1}{\sqrt{2\pi\sigma_0^2}}e^{-\frac{(x-\mu)^2}{2\sigma_0^2}} $$ + +使用神经网络拟合参数$\mu$: + +$$ \mu = \mathop{\arg \max}_{\mu} \log p(x|z) = \mathop{\arg \min}_{\mu} \frac{(x-\mu)^2}{2\sigma_0^2} $$ + +上式表示**均方误差(MSE)**,**Pytorch**实现如下: + +```python +recons_loss = F.mse_loss(recons, input, reduction = 'sum') +``` + +注意`reduction`参数可选`'sum'`和`'mean'`,应该使用`'sum'`,这使得损失函数计算与原式保持一致。笔者在实现时曾选用`'mean'`,导致即使训练损失有下降,也只能生成噪声图片,推测是因为取平均使重构损失误差占比过小,无法正常训练。 + + +## (3) 重参数化技巧 +**VAE**的损失函数如下: + +$$ \mathcal{L} = \mathbb{E}_{z \text{~} q(z|x)} [-\log p(x | z)] + KL[q(z|x)||p(z)] $$ + +其中期望$$\mathbb{E}_{z \text{~} q(z\|x)} [\cdot]$$表示从从$q(z\|x)$中采样$z$的过程。由于采样过程是不可导的,不能直接参与梯度传播,因此引入[重参数化(reparameterization)技巧](https://0809zheng.github.io/2022/04/24/repere.html)。 + +已经假设$z$~$q(z\|x)$服从$$\mathcal{N}(\mu,\sigma^{2})$$,则$\epsilon=\frac{z-\mu}{\sigma}$服从标准正态分布$$\mathcal{N}(0,I)$$。因此有如下关系: + +$$ z = \mu + \sigma \cdot \epsilon $$ + +则从$$\mathcal{N}(0,I)$$中采样$\epsilon$,再经过参数变换构造$z$,可使得采样操作不用参与梯度下降,从而实现模型端到端的训练。 + +![](https://pic.imgdb.cn/item/627a22c909475431292e6d8d.jpg) + +在实现时对于每个样本只进行一次采样,采样的充分性是通过足够多的批量样本与训练轮数来保证的。则损失函数也可写作: + +$$ \mathcal{L} = -\log p(x | z) + KL[q(z|x)||p(z)], \quad z \text{~} q(z|x)$$ + +重参数化技巧的**Pytorch**实现如下: + +```python +def reparameterize(mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return mu + eps * std +``` + +### ⚪讨论:VAE的另一种建模方式 +对于一批已有样本,代表一个真实但形式未知的概率分布$\tilde{p}(x)$,可以构建一个带参数$\phi$的后验分布$q_{\phi}(z\|x)$,从而组成联合分布$q(x,z)=\tilde{p}(x)q_{\phi}(z\|x)$。如果人为定义一个先验分布$p(z)$,并构建一个带参数$\theta$的生成分布$p_{\theta}(x\|z)$,则可以构造另一个联合分布$p(x,z)=p(z)p_{\theta}(x\|z)$。**VAE**的目的是联合分布$q(x,z),p(x,z)$足够接近,因此最小化两者之间的**KL**散度: + +$$ \begin{aligned} KL(q(x,z)||p(x,z)) &= \mathbb{E}_{q(x,z)}[\log \frac{q(x,z)}{p(x,z)}] = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{\tilde{p}(x)q_{\phi}(z|x)}{p(z)p_{\theta}(x|z)}]] \\& = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[-\log p_{\theta}(x|z)] + \mathbb{E}_{q_{\phi}(z|x)}[\log \frac{q_{\phi}(z|x)}{p(z)}] + \mathbb{E}_{q_{\phi}(z|x)}[\log \tilde{p}(x)]] \\ &= \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[-\log p_{\theta}(x|z)] + KL(q_{\phi}(z|x)||p(z)) + Const.] \end{aligned} $$ + +### ⚪讨论:KL散度与互信息 + +上述联合分布的**KL**散度也可写作: + +$$ \begin{aligned} KL(q(x,z)||p(x,z)) & = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{\tilde{p}(x)q_{\phi}(z|x)}{p(z)p_{\theta}(x|z)}]] \\& = \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{q_{\phi}(z|x)}{p(z)}]] - \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{p_{\theta}(x|z)}{\tilde{p}(x)}]] \\& = \mathbb{E}_{\tilde{p}(x)} [KL(q_{\phi}(z|x)||p(z))] - \mathbb{E}_{\tilde{p}(x)} [\mathbb{E}_{q_{\phi}(z|x)}[\log \frac{p_{\theta}(x,z)}{\tilde{p}(x)p(z)}]] \end{aligned} $$ + +其中第一项为隐变量$z$的后验分布与先验分布之间的**KL**散度,第二项为观测变量$x$与为隐变量$z$之间的点互信息。因此**VAE**的优化目标也可以解释为最小化隐变量**KL**散度的同时最大化隐变量与观测变量的互信息。 + +# 3. 变分自编码器的各种变体 + +**VAE**的损失函数共涉及三个不同的概率分布:由概率编码器表示的后验分布$q(z\|x)$、隐变量的先验分布$p(z)$以及由概率解码器表示的生成分布$p(x\|z)$。对**VAE**的各种改进可以落脚于对这些概率分布的改进: + +- 后验分布$q(z\|x)$:后验分布为模型引入了正则化;一种改进思路是通过调整后验分布的正则化项增强模型的解耦能力(如$\beta$-**VAE**, **Disentangled** $\beta$-**VAE**, **InfoVAE**, **DIP-VAE**, **FactorVAE**, $\beta$-**TCVAE**, **HFVAE**)。 +- 先验分布$p(z)$:先验分布描绘了隐变量分布的隐空间;一种改进思路是通过引入标签实现半监督学习(如**CVAE**, **CMMA**);一种改进思路是通过对隐变量离散化实现聚类或分层特征表示(如**Categorical VAE**, **Joint VAE**, **VQ-VAE**, **VQ-VAE-2**, **FSQ**);一种改进思路是更换隐变量的概率分布形式(如**Hyperspherical VAE**, **TD-VAE**, **f-VAE**, **NVAE**)。 +- 生成分布$p(x\|z)$:生成分布代表模型的数据重构能力;一种改进思路是将均方误差损失替换为其他损失(如**EL-VAE**, **DFCVAE**, **LogCosh VAE**)。 +- 改进整体损失函数:也有方法通过调整整体损失改进模型,如紧凑变分下界(如**IWAE**, **MIWAE**)或引入**Wasserstein**距离(如**WAE**, **SWAE**)。 +- 改进模型结构:如[BN-VAE](https://0809zheng.github.io/2022/04/18/bnvae.html)通过引入**BatchNorm**缓解**KL**散度消失问题(指较强的解码器允许训练时**KL**散度项$KL[q(z\|x)\|\|p(z)]=0$);引入对抗训练(如[AAE](https://0809zheng.github.io/2022/02/20/aae.html), [VAE-GAN](https://0809zheng.github.io/2022/02/17/vaegan.html))。 + + + + +| 方法 | 损失函数 | +| :---: | :---: | +| VAE | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+KL[q(z\|x)\|\|p(z)]$ | +| [CVAE](https://0809zheng.github.io/2022/04/02/cvae.html)
引入条件 | $\mathbb{E}_{z \text{~} q(z\|x,y)} [-\log p(x\|z,y)]+KL[q(z\|x,y)\|\|p(z\|y)]$ | +| [CMMA](https://0809zheng.github.io/2022/04/03/cmma.html)
隐变量$z$由标签$y$决定 | $\mathbb{E}_{z \text{~} q(z\|x,y)} [-\log p(x\|z)]+KL[q(z\|x,y)\|\|p(z\|y)]$ | +| [β-VAE](https://0809zheng.github.io/2020/12/02/bvae.html)
特征解耦 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+\beta \cdot KL[q(z\|x)\|\|p(z)]$ | +| [Disentangled β-VAE](https://0809zheng.github.io/2020/12/03/bvae2.html)
特征解耦 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+\gamma \cdot \|KL[q(z\|x)\|\|p(z)]-C\|$ | +| [InfoVAE](https://0809zheng.github.io/2020/12/04/infovae.html)
特征解耦 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+(1-\alpha) \cdot KL[q(z\|x)\|\|p(z)]+(\alpha+\lambda-1)\cdot \mathcal{D}_Z(q(z),p(z))$ | +| [DIP-VAE](https://0809zheng.github.io/2022/04/17/dipvae.html)
分离推断先验 | $$\begin{aligned} &\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+KL[q(z\|x)\|\|p(z)]\\&+\lambda_{od} \sum_{i \ne j} [\text{Cov}_{q(z)}[z]]^2_{ij}+\lambda_{d} \sum_{i} ([\text{Cov}_{q(z)}[z]]_{ii}-1)^2 \end{aligned}$$ | +| [FactorVAE](https://0809zheng.github.io/2022/04/15/factorvae.html)
特征解耦 | $$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+KL[q(z\|x)\|\|p(z)] +\gamma KL(q(z)\|\|\prod_{j}q(z_j))$$ | +| [β-TCVAE](https://0809zheng.github.io/2022/04/05/btcvae.html)
分离全相关项 | $$\begin{aligned} & \mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)]+\alpha KL(q(z,x)\|\|q(z)p(x)) \\ & +\beta KL(q(z)\|\|\prod_{j}q(z_j)) +\gamma \sum_{j}KL(q(z_j)\|\|p(z_j)) \end{aligned}$$ | +| [HFVAE](https://0809zheng.github.io/2022/04/16/hfvae.html)
隐变量特征分组 | $$\begin{aligned}&\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)] + \sum_{i}KL[q(z_{i})\|\|p(z_{i})] + \alpha KL(q(z,x)\|\|q(z)p(x)) \\ &+\beta \Bbb{E}_{q(z)}[\log \frac{q(z)}{\prod_{g}q(z^g)}-\log \frac{p(z)}{\prod_{g}p(z^g)}] + \gamma \sum_{g} \Bbb{E}_{q(z^g)}[\log \frac{q(z^g)}{\prod_{j}q(z^g_j)}-\log \frac{p(z^g)}{\prod_{j}p(z^g_j)}]\end{aligned}$$ | +| [Categorical VAE](https://0809zheng.github.io/2022/04/10/catevae.html)
离散隐变量: **Gumbel Softmax** | $\mathbb{E}_{z \text{~} q(c\|x)} [-\log p(x\|c)]+ KL[q(c\|x)\|\|p(c)]$ | +| [Joint VAE](https://0809zheng.github.io/2022/04/11/jointvae.html)
连续+离散隐变量 | $\mathbb{E}_{z,c \text{~} q(z,c\|x)} [-\log p(x\|z,c)]+\gamma_z \cdot \|KL[q(z\|x)\|\|p(z)]-C_z\|+\gamma_c \cdot \|KL[q(c\|x)\|\|p(c)]-C_c\|$ | +| [VQ-VAE](https://0809zheng.github.io/2020/11/10/vqvae.html)
向量量化隐变量 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z_e + sg[z_q-z_e])] + \|\| sg[z_e] - z_q \|\|_2^2 + \beta \|\| z_e - sg[z_q] \|\|_2^2$ | +| [VQ-VAE-2](https://0809zheng.github.io/2020/11/11/vqvae2.html)
VQ-VAE分层 | 同上,$z=\{z^{bottom},z^{top}\}$ | +| [FSQ](https://0809zheng.github.io/2023/09/27/fsq.html)
有限标量量化 | $\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z + sg[\text{round} \left( \lfloor \frac{L}{2} \rfloor \tanh(z) \right)-z])]$ | +| [Hyperspherical VAE](https://0809zheng.github.io/2022/04/19/svae.html)
引入**vMF**分布 | $$\mathbb{E}_{z \text{~} q(z\|x)} [-\log p(x\|z)],\quad p(z)\text{~}C_{d,\kappa} e^{\kappa <\mu(x),z>}$$ | +| [TD-VAE](https://0809zheng.github.io/2022/04/21/tdvae.html)
**Markov**链状态空间 | $$\begin{aligned} \Bbb{E}_{(z_{t_1},z_{t_2}) \text{~} q}[ & -\log p_D(x_{t_2}\|z_{t_2}) -\log p_B(z_{t_1}\|b_{t_1}) -\log p_T(z_{t_2}\|z_{t_1}) \\ & +\log p_B(z_{t_2}\|b_{t_2})+\log q_S(z_{t_1}\|z_{t_2},b_{t_1},b_{t_2})] \end{aligned}$$ | +| [f-VAE](https://0809zheng.github.io/2022/04/22/fvae.html)
引入流模型 | $$\mathbb{E}_{u \text{~} q(u)} [-\log p(x\|F_x(u))-\log \|\det[\frac{\partial F_x(u)}{\partial u}]\|]+KL[q(u)\|\|p(F_x(u))]$$ | +| [NVAE](https://0809zheng.github.io/2022/04/20/nvae.html)
引入自回归高斯模型 | $$\begin{aligned}&\mathbb{E}_{z \text{~} q(z\|x,y)} [-\log p(x\|z,y)]+KL[q(z\|x,y)\|\|p(z\|y)] \\ &p(z) = \prod_{l=1}^{L} p(z_l\|z_{\lt l}) ,q(z\|x) = \prod_{l=1}^{L} q(z_l\|z_{\lt l},x)\end{aligned}$$ | +| [EL-VAE](https://0809zheng.github.io/2022/04/12/msssim.html)
引入**MS-SSIM**损失 | $I_M(x,\hat{x})^{\alpha_M} \prod_{j=1}^{M} C_j(x,\hat{x})^{\beta_j} S_j(x,\hat{x})^{\gamma_j}+ KL[q(z\|x)\|\|p(z)]$ | +| [DFCVAE](https://0809zheng.github.io/2022/04/09/dfcvae.html)
引入特征感知损失 | $$\alpha \sum_{l=1}^{L}\frac{1}{2C^lH^lW^l}\sum_{c=1}^{C^l}\sum_{h=1}^{H^l}\sum_{w=1}^{W^l}(\Phi(x)_{c,h,w}^l-\Phi(\hat{x})_{c,h,w}^l)^2+\beta \cdot KL[q(z\|x)\|\|p(z)]$$ | +| [LogCosh VAE](https://0809zheng.github.io/2022/04/13/logcosh.html)
引入**log cosh**损失 | $$\frac{1}{a} \log(\frac{e^{a(x-\hat{x})}+e^{-a(x-\hat{x})}}{2})+\beta \cdot KL[q(z\|x)\|\|p(z)]$$ | +| [IWAE](https://0809zheng.github.io/2022/04/07/iwae.html)
紧凑变分下界 | $$\Bbb{E}_{z_1,z_2,\cdots z_K \text{~} q(z\|x)}[\log \frac{1}{K}\sum_{k=1}^{K}\frac{p(x,z_k)}{q(z_k\|x)}]$$ | +| [MIWAE](https://0809zheng.github.io/2022/04/08/miwae.html)
紧凑变分下界 | $$\frac{1}{M}\sum_{m=1}^{M} \Bbb{E}_{z_{m,1},z_{m,2},\cdots z_{m,K} \text{~} q(z_{m}\|x)}[\log \frac{1}{K}\sum_{k=1}^{K}\frac{p(x,z_{m,k})}{q(z_{m,k}\|x)}]$$ | +| [WAE](https://0809zheng.github.io/2022/04/04/wae.html)
引入**Wasserstein**距离 | $$\Bbb{E}_{x\text{~}p(z)}\Bbb{E}_{z \text{~} q(z\|x)} [c(x,p(x\|z))]+\lambda \cdot \mathcal{D}_Z(q(z),p(z))$$ | +| [SWAE](https://0809zheng.github.io/2022/04/14/swae.html)
引入**Sliced-Wasserstein**距离 | $$\Bbb{E}_{x\text{~}p(z)}\Bbb{E}_{z \text{~} q(z\|x)} [c(x,p(x\|z))]+\int_{\Bbb{S}^{d-1}} \mathcal{W}[\mathcal{R}_{q(z)}(\cdot ;\theta),\mathcal{R}_{p(z)}(\cdot ;\theta)] d\theta$$ | + +# ⚪ 参考文献 +- [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114):(arXiv1312)VAE的原始论文。 +- [From Autoencoder to Beta-VAE](https://lilianweng.github.io/posts/2018-08-12-vae/):Blog by Lilian Weng. +- [变分自编码器(一):原来是这么一回事](https://spaces.ac.cn/archives/5253):Blog by 苏剑林. +- [Recent Advances in Autoencoder-Based Representation Learning](https://arxiv.org/abs/1812.05069v1):(arXiv1812)一篇VAE综述。 +- [PyTorch-VAE: A Collection of Variational Autoencoders (VAE) in PyTorch.](https://github.com/AntixK/PyTorch-VAE):(github)VAE的PyTorch实现。 +- [Learning Structured Output Representation using Deep Conditional Generative Models](https://0809zheng.github.io/2022/04/02/cvae.html):(NeurIPS2015)CVAE: 使用深度条件生成模型学习结构化输出表示。 +- [Importance Weighted Autoencoders](https://0809zheng.github.io/2022/04/07/iwae.html):(arXiv1509)IWAE:重要性加权自编码器。 +- [Learning to Generate Images with Perceptual Similarity Metrics](https://0809zheng.github.io/2022/04/12/msssim.html):(arXiv1511)使用多尺度结构相似性度量MS-SSIM学习图像生成。 +- [Adversarial Autoencoders](https://0809zheng.github.io/2022/02/20/aae.html):(arXiv1511)AAE:对抗自编码器。 +- [Autoencoding beyond pixels using a learned similarity metric](https://0809zheng.github.io/2022/02/17/vaegan.html):(arXiv1512)VAE-GAN:结合VAE和GAN。 +- [Variational methods for Conditional Multimodal Learning: Generating Human Faces from Attributes](https://0809zheng.github.io/2022/04/03/cmma.html):(arXiv1603)CMMA: 条件多模态学习的变分方法。 +- [Deep Feature Consistent Variational Autoencoder](https://0809zheng.github.io/2022/04/09/dfcvae.html):(arXiv1610)DFCVAE:使用特征感知损失约束深度特征一致性。 +- [Categorical Reparameterization with Gumbel-Softmax](https://0809zheng.github.io/2022/04/10/catevae.html):(arXiv1611)使用Gumble-Softmax实现离散类别隐变量的重参数化。 +- [β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework](https://0809zheng.github.io/2020/12/02/bvae.html):(ICLR1704)β-VAE:学习变分自编码器隐空间的解耦表示。 +- [InfoVAE: Balancing Learning and Inference in Variational Autoencoders](https://0809zheng.github.io/2020/12/04/infovae.html):(arXiv1706)InfoVAE:平衡变分自编码器的学习和推断过程。 +- [Variational Inference of Disentangled Latent Concepts from Unlabeled Observations](https://0809zheng.github.io/2022/04/17/dipvae.html):(arXiv1711)DIP-VAE: 分离推断先验VAE。 +- [Wasserstein Auto-Encoders](https://0809zheng.github.io/2022/04/04/wae.html):(arXiv1711)WAE: 使用**Wasserstein**距离的变分自编码器。 +- [Neural Discrete Representation Learning](https://0809zheng.github.io/2020/11/10/vqvae.html):(arXiv1711)VQ-VAE:向量量化的变分自编码器。 +- [Tighter Variational Bounds are Not Necessarily Better](https://0809zheng.github.io/2022/04/08/miwae.html):(arXiv1802)MIWAE:紧凑的变分下界阻碍推理网络训练。 +- [Disentangling by Factorising](https://0809zheng.github.io/2022/04/15/factorvae.html):(arXiv1802)FactorVAE:通过分解特征表示的分布进行解耦。 +- [Isolating Sources of Disentanglement in Variational Autoencoders](https://0809zheng.github.io/2022/04/05/btcvae.html):(arXiv1802)β-TCVAE: 分离VAE解耦源中的全相关项。 +- [Understanding disentangling in β-VAE](https://0809zheng.github.io/2020/12/03/bvae2.html):(arXiv1804)使用信息瓶颈解释β-VAE的解耦表示能力。 +- [Structured Disentangled Representations](https://0809zheng.github.io/2022/04/16/hfvae.html):(arXiv1804)HFVAE:通过层级分解VAE实现结构化解耦表示。 +- [Learning Disentangled Joint Continuous and Discrete Representations](https://0809zheng.github.io/2022/04/11/jointvae.html):(arXiv1804)Joint VAE:学习解耦的联合连续和离散表示。 +- [Sliced-Wasserstein Autoencoder: An Embarrassingly Simple Generative Model](https://0809zheng.github.io/2022/04/14/swae.html):(arXiv1804)SWAE:引入Sliced-Wasserstein距离构造VAE。 +- [Hyperspherical Variational Auto-Encoders](https://0809zheng.github.io/2022/04/19/svae.html):(arXiv1804)Hyperspherical VAE: 为隐变量引入vMF分布。 +- [Temporal Difference Variational Auto-Encoder](https://0809zheng.github.io/2022/04/21/tdvae.html):(arXiv1806)TD-VAE: 时间差分变分自编码器。 +- [f-VAEs: Improve VAEs with Conditional Flows](https://0809zheng.github.io/2022/04/22/fvae.html):(arXiv1809)f-VAE: 基于流的变分自编码器。 +- [Log Hyperbolic Cosine Loss Improves Variational Auto-Encoder](https://0809zheng.github.io/2022/04/13/logcosh.html):(OpenReview2018)使用对数双曲余弦损失改进变分自编码器。 +- [Generating Diverse High-Fidelity Images with VQ-VAE-2](https://0809zheng.github.io/2020/11/11/vqvae2.html):(arXiv1906)VQ-VAE-2:改进VQ-VAE生成高保真度图像。 +- [A Batch Normalized Inference Network Keeps the KL Vanishing Away](https://0809zheng.github.io/2022/04/18/bnvae.html):(arXiv2004)BN-VAE: 通过批量归一化缓解KL散度消失问题。 +- [NVAE: A Deep Hierarchical Variational Autoencoder](https://0809zheng.github.io/2022/04/20/nvae.html):(arXiv2007)Nouveau VAE: 深度层次变分自编码器。 +- [Finite Scalar Quantization: VQ-VAE Made Simple](https://0809zheng.github.io/2023/09/27/fsq.html):(arXiv2309)有限标量量化:简化向量量化的变分自编码器。