# VQ-GAN Paper Reading + Code Walkthrough

> First written: 0928
> 
> Second update: 2025-10-08 19:09:54 Wednesday

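Key points:
1. **Attention in VQ-GAN operates on multi-channel CNN feature maps (called "channel attention" in these notes)**: the design builds on CNN feature extraction. Attention is not applied directly at the pixel level; it is computed on the multi-channel feature maps produced by the convolutional network, so it adds global dependencies while the convolutions keep local spatial consistency. In the `AttnBlock` code below, each spatial position acts as a token whose feature is its channel vector, so the attention matrix has shape `(hw, hw)`.

2.  **VQ-GAN preserves rich image information and reconstructs well mainly because of how its tokens (codebook vectors) are designed**: a discrete token does not merely encode the color of a pixel patch. It is a high-level representation obtained after a multi-layer convolutional encoder downsamples the image, so each token carries the structure and semantic context of its neighborhood. These context-rich embedding tokens let the Transformer stage model global object shape, boundaries, and semantic relations, and therefore synthesize images that are structurally consistent and realistic.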





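## Motivation

Transformers face **two fundamental bottlenecks in image generation**. To address them, the paper combines the efficient local modeling of convolutional networks with the strong global modeling of Transformers, so that a Transformer can generate high-resolution images efficiently while keeping global semantic consistency.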
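### **1️⃣ Bottleneck 1: pixel-level Transformers are too expensive**

* Autoregressive modeling directly in pixel space (e.g. PixelRNN, Image Transformer) treats every pixel as a sequence element. With full self-attention over an $H \times W$ image the cost is
  $$\text{complexity} = O(H^2W^2)$$
  that is, quadratic in the number of pixels $HW$ (fourth power of the side length for a square image).
* At high resolution (e.g. 512×512 or larger) training and sampling become practically infeasible.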
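
To get a feel for the gap, the sketch below counts query-key pairs for full self-attention over raw pixels versus over a downsampled token grid. The helper `attention_pairs` and the downsampling factor of 16 are assumptions chosen for illustration, not values taken from these notes.

```python
def attention_pairs(height: int, width: int, downsample: int = 1) -> int:
    """Number of query-key pairs for full self-attention over an image grid.

    `downsample` is a hypothetical per-side reduction factor of an encoder;
    downsample=1 means attention directly over pixels.
    """
    seq_len = (height // downsample) * (width // downsample)
    return seq_len * seq_len


pixel_pairs = attention_pairs(512, 512)                 # sequence of 512*512 pixels
token_pairs = attention_pairs(512, 512, downsample=16)  # sequence of 32*32 tokens

print(f"pixel-level pairs: {pixel_pairs:,}")
print(f"token-level pairs: {token_pairs:,}")
print(f"reduction factor:  {pixel_pairs // token_pairs:,}")
```

Reducing each side by a factor `f` shrinks the sequence by `f**2` and the attention cost by `f**4`, which is why moving from pixels to encoder tokens matters so much.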


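### **2️⃣ Bottleneck 2: pixels lack semantic context**

* When each pixel or small patch is encoded independently, the Transformer has to learn global structure and semantics (object shapes, contours) on its own, which is very inefficient;
* Without an inductive bias, training tends to be unstable and sample quality low.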


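### 💡 Proposed solution

> **"Let the Transformer look at semantic tokens instead of raw pixels."**

* The **VQ-GAN** encoder (a CNN) first compresses the image into **context-rich discrete tokens** (codes from a visual vocabulary, the codebook).
* These tokens are not raw pixels; they are high-level features that fuse local structure and semantics.
* The Transformer then performs autoregressive modeling over this **compact, semantically meaningful token sequence**.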



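## GroupNorm?

- GroupNorm splits the channels of the input feature map into groups and normalizes each group independently: for every sample it computes the mean and variance over the channels and spatial positions of a group, then applies a learnable scale and shift.
- Because the statistics are computed per sample over channel groups, GroupNorm does not depend on the batch size and is more robust than BatchNorm when batches are small.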
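
As a quick check of the description above, the following sketch recomputes GroupNorm by hand and compares it with `torch.nn.GroupNorm`. The tensor sizes are arbitrary assumptions; `num_groups=32` and `eps=1e-6` mirror the `Normalize` helper in the next cell.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 8)   # (batch, channels, height, width), sizes chosen arbitrarily
num_groups = 32

gn = nn.GroupNorm(num_groups=num_groups, num_channels=64, eps=1e-6, affine=True)

# Manual version: statistics per sample and per group of 64/32 = 2 channels.
b, c, h, w = x.shape
xg = x.reshape(b, num_groups, -1)                    # (b, groups, channels_per_group*h*w)
mean = xg.mean(dim=2, keepdim=True)
var = xg.var(dim=2, unbiased=False, keepdim=True)    # biased variance, as normalization layers use
manual = ((xg - mean) / torch.sqrt(var + 1e-6)).reshape(b, c, h, w)

# With freshly initialized affine parameters (weight=1, bias=0) both results agree.
matches_manual = torch.allclose(gn(x), manual, atol=1e-5)

# Each sample is normalized on its own, so the result does not depend on the batch size.
batch_independent = torch.allclose(gn(x[:1]), gn(x)[:1], atol=1e-6)

print(matches_manual, batch_independent)
```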

In [56]:
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from icecream import ic
def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0,1,0,1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


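## Channel attention

**What is channel attention?**

Channel attention re-weights the channels of a feature map, emphasizing the important channels and suppressing the less important ones. Seen from a single channel, the effect is roughly an affine transform (a scale and a shift).



**What attention means in different tasks**

- NLP tasks: attention models the sequence dimension (relations between words) and highlights the key tokens in a text sequence.

- Vision tasks: attention can act along different dimensions
    - Channel attention: emphasizes informative channels and suppresses minor ones.
    - Spatial attention: highlights the importance of different patches or pixel positions.




**Positional encoding**

- Text attention: a 1D positional encoding is required to retain the order of the sequence.

- Spatial attention on images: a 2D positional encoding is usually required, especially for global attention or Transformer architectures, to retain the spatial relations between pixels or patches.

- Channel attention on images: the input is usually a feature map extracted by convolutions, which already carries spatial information implicitly, so no extra positional encoding is added.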
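
One way to tell which dimension an attention map mixes is to look at its shape. The sketch below is a simplification that uses the raw features as both queries and keys (real blocks use learned projections) and builds both variants on a toy feature map. The `w_` tensor computed in the `AttnBlock` cell that follows has the first shape, `(b, hw, hw)`, so that block attends across spatial positions with the channel vector as the feature.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)        # toy feature map: (batch, channels, height, width)
b, c, h, w = x.shape
feat = x.reshape(b, c, h * w)      # (b, c, hw)

# Tokens = spatial positions, feature = channel vector at that position.
spatial_attn = torch.softmax(feat.transpose(1, 2) @ feat * c ** -0.5, dim=-1)

# Tokens = channels, feature = flattened spatial map of that channel.
channel_attn = torch.softmax(feat @ feat.transpose(1, 2) * (h * w) ** -0.5, dim=-1)

print(spatial_attn.shape)   # one weight per pair of positions
print(channel_attn.shape)   # one weight per pair of channels
```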



In [51]:
# pytorch_diffusion + derived encoder decoder

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        ic(q.shape)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        ic(q.shape)
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        ic(w_.shape)
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_



def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


img=torch.rand(2,512,16,16)
attn=AttnBlock(img.shape[1])

out=attn(img)
ic(out.shape)


ic| q.shape: torch.Size([2, 512, 16, 16])
ic| q.shape: torch.Size([2, 256, 512])
ic| w_.shape: torch

torch.Size([2, 512, 16, 16])

## Residual block

In [52]:
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h

In [53]:
# pytorch_diffusion + derived encoder decoder

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w)
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

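## Encoder architecture


1. A first convolution maps the input image to an initial feature map of the same spatial size with multiple channels (`ch=128`).
2. Downsampling stage:
    - It has 3 sub-stages; sub-stage `i` uses `ch_mult[i]*ch` channels, and the resolution is halved between consecutive sub-stages (2 halvings for 3 sub-stages: 64 -> 32 -> 16 in the run below).
    - With `ch_mult = [1, 2, 4]` the channel count goes `ch=128 -> 128 -> 256 -> 512`.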
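
The bookkeeping in the `Encoder` constructor below can be traced without building any layers. This sketch replays the same loop logic (`in_ch_mult`, `block_in`, `block_out`, `curr_res`); the function name `encoder_schedule` is made up for illustration.

```python
def encoder_schedule(ch, ch_mult, resolution):
    """Return (level, in_channels, out_channels, resolution) for each encoder level."""
    in_ch_mult = (1,) + tuple(ch_mult)
    curr_res = resolution
    rows = []
    for i_level in range(len(ch_mult)):
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        rows.append((i_level, block_in, block_out, curr_res))
        if i_level != len(ch_mult) - 1:   # the last level is not followed by a downsample
            curr_res //= 2
    return rows


schedule = encoder_schedule(ch=128, ch_mult=[1, 2, 4], resolution=64)
for level, block_in, block_out, res in schedule:
    print(f"level {level}: {block_in:>3} -> {block_out:>3} channels at {res}x{res}")
```

With the `ddconfig` used in these notes this gives 128 -> 128 at 64x64, 128 -> 256 at 32x32, and 256 -> 512 at 16x16, matching the `[2, 512, 16, 16]` shape printed in the middle block.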
    

In [45]:


class Encoder(nn.Module):
    def __init__(self, *, ch=1, out_ch=3, ch_mult=(1,2,4,8), num_res_blocks=2,
                 attn_resolutions=(), dropout=0.0, resamp_with_conv=True, in_channels=2,  # () instead of 1: `curr_res in attn_resolutions` needs an iterable
                 resolution=1, z_channels=2, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        ic(attn_type)
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            ic(i_level)
            ic(block_in)
            ic(block_out)
            ic(in_ch_mult,ch_mult)
            ic(self.num_res_blocks)
            
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        ic(self.down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        ic(block_in)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        
        ic(self.mid)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
        ic(self.conv_out)

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        ic(h.shape)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

ddconfig = {
    "double_z": False,
    "z_channels": 3,
    "resolution": 64,
    "in_channels": 1,
    "out_ch": 1,
    "ch": 128,
    "ch_mult": [1, 2, 4],
    "num_res_blocks": 3,
    "attn_resolutions": [],
    "dropout": 0.0
}

encoder=Encoder(**ddconfig)

random_input = torch.randn(2, ddconfig["in_channels"], ddconfig["resolution"], ddconfig["resolution"])

# forward pass
output = encoder(random_input)


ic| attn_type: 'vanilla'
ic| i_level: 0
ic| block_in: 128
ic| block_out: 128
ic| in_ch_mult: (1, 1, 2, 4), ch_mult: [1,

making attention of type 'vanilla' with 512 in_channels


ic| h.shape: torch.Size([2, 512, 16, 16])
ic| q.shape: torch.Size([2, 512, 16, 16])
ic| q.shape:

## Decoder architecture

In [57]:
class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h

decoder=Decoder(**ddconfig)


Working with z of shape (1, 3, 16, 16) = 768 dimensions.
making attention of type 'vanilla' with 512 in_channels


## Vector quantizer and the quantization loss $\mathcal{L}_{codebook}$

From `VectorQuantizer2.forward()`:

```python
if not self.legacy:
    loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
           torch.mean((z_q - z.detach()) ** 2)
else:
    loss = torch.mean((z_q.detach()-z)**2) + \
           self.beta * torch.mean((z_q - z.detach()) ** 2)
```

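### The core of VQ-VAE:
The encoder output $z_e(x)$ is replaced by its nearest codebook vector $e_k$.


$$\mathcal{L}_{codebook} =
\| \text{sg}[z_e(x)] - e_k \|_2^2 +
\beta \| z_e(x) - \text{sg}[e_k] \|_2^2$$

Here `sg[·]` denotes **stop-gradient**: its argument is treated as a constant, so no gradient flows into it.
The first term updates the codebook embeddings;
the second term (the commitment term) updates the encoder, pulling its output toward the nearest embedding.
With the default `legacy=True`, the code above applies `beta` to the codebook term instead (see the NOTE in the class below); `legacy=False` matches the formula.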
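
To see how `detach()` routes the gradients of the two terms, here is a minimal sketch on random stand-in tensors. The value `beta = 0.25` and the tensor shapes are assumptions for illustration; `z` plays the role of $z_e(x)$ and `e` the role of the selected codebook vectors $e_k$.

```python
import torch

torch.manual_seed(0)
beta = 0.25                                   # assumed value, for illustration only
z = torch.randn(4, 3, requires_grad=True)     # stand-in for encoder outputs z_e(x)
e = torch.randn(4, 3, requires_grad=True)     # stand-in for the selected codebook vectors e_k

codebook_term = torch.mean((z.detach() - e) ** 2)   # sg[z_e(x)]: gradient reaches only e
commit_term = torch.mean((z - e.detach()) ** 2)     # sg[e_k]: gradient reaches only z
loss = codebook_term + beta * commit_term
loss.backward()

# Analytic gradients of the two mean-squared terms.
n = z.numel()
expected_e_grad = 2 * (e - z).detach() / n
expected_z_grad = beta * 2 * (z - e).detach() / n

print(torch.allclose(e.grad, expected_e_grad))   # codebook receives the unscaled term
print(torch.allclose(z.grad, expected_z_grad))   # encoder receives the beta-scaled term
```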

### 🧩 Trick: Straight-Through Estimator (STE)

```python
z_q = z + (z_q - z).detach()
```

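This line is the **straight-through gradient estimator**:

* Forward: the output equals the quantized value `z_q`.
* Backward: the gradient is passed straight back to `z`, bypassing the non-differentiable quantization step.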
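
The two bullets above can be verified on a toy example. The codebook, the encoder outputs, and the "upstream" gradient below are made-up numbers; the lookup uses `torch.cdist` for brevity instead of the expanded distance formula in `VectorQuantizer2`.

```python
import torch

codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])    # toy codebook with 3 entries
z = torch.tensor([[0.9, 1.2], [-0.8, 1.7]], requires_grad=True)   # toy encoder outputs

# Nearest-neighbour lookup; argmin is not differentiable.
indices = torch.cdist(z, codebook).argmin(dim=1)
z_q = codebook[indices]

# Straight-through: the value of z_q, the gradient path of z.
z_st = z + (z_q - z).detach()

upstream = torch.tensor([[1.0, 2.0], [3.0, 4.0]])   # pretend gradient coming from the decoder
(z_st * upstream).sum().backward()

print(indices.tolist())                  # which codebook entry each vector selected
print(torch.allclose(z_st, z_q))         # forward value equals the quantized vector
print(torch.equal(z.grad, upstream))     # gradient reaches z unchanged
```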


In [None]:

class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed+1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        match = (inds[:,:,None]==used[None,None,...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2)<1
        if self.unknown_index == "random":
            new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape)>1
        inds = inds.reshape(ishape[0],-1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]: # extra token
            inds[inds>=self.used.shape[0]] = 0 # simply set to zero
        back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits==False, "Only for interface compatible with Gumbel"
        assert return_logits==False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0],-1) # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1) # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q



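## The VQ-GAN loss (implemented as the `VQLPIPSWithDiscriminator` class)
This class implements the loss structure from the VQ-GAN paper:
$$\mathcal{L}_{total} = \mathcal{L}_{rec/perceptual} + \lambda \mathcal{L}_{GAN} + \beta \mathcal{L}_{codebook}$$
An adaptive weight λ balances the gradient magnitudes of the reconstruction and adversarial parts.


where:
  * **$\mathcal{L}_{rec}$**: reconstruction loss (pixel + perceptual)
  * **$\mathcal{L}_{GAN}$**: adversarial loss (generator part)
  * **$\mathcal{L}_{codebook}$**: quantization loss (updates the codebook embeddings)
  * **$\lambda$**: adaptive balancing weight (ratio of gradient norms)
  * **$\beta$**: weight on the quantization loss; in the code the overall weight is `codebook_weight`, while `beta` inside the quantizer scales one of the two quantization terms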

---

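### 🧱 1. Model components

| Module          | Meaning                                                                 |
| --------------- | ----------------------------------------------------------------------- |
| `E` / `G` / `Z` | encoder, decoder, and codebook, respectively                            |
| `D`             | discriminator (PatchGAN style)                                          |
| `LPIPS`         | perceptual loss; feature differences computed with a pretrained network |
| `codebook_loss` | codebook quantization loss (commitment + embedding)                     |
| `disc_loss`     | discriminator loss type, either "hinge" or "vanilla"                    |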
---

### ⚙️ 2. Overall forward logic

`forward()` branches on `optimizer_idx`:

1. `optimizer_idx == 0`: update the generator (E, G, Z)
2. `optimizer_idx == 1`: update the discriminator D

---

### 🧩 3. Reconstruction + perceptual loss

```python
rec_loss = torch.abs(inputs - reconstructions)
p_loss = self.perceptual_loss(inputs, reconstructions)
rec_loss = rec_loss + self.perceptual_weight * p_loss
nll_loss = torch.mean(rec_loss)
```

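Corresponding formula:

$$\mathcal{L}_{rec/perceptual} =
\mathbb{E}\left[\|x - \hat{x}\|_1\right] +
\lambda_p \|\phi(x) - \phi(\hat{x})\|_2^2$$

* `L1` replaces the pixel-level L2 and is more robust to outliers;
* `LPIPS` provides a feature distance at the perceptual level;
* `perceptual_weight` ($\lambda_p$) sets the weight of the perceptual term;
* `nll_loss` plays the role of $\mathcal{L}_{rec}$ in Eq. (6).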

---

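### 🧮 4. Codebook quantization loss

In VQ-VAE/VQ-GAN:

$$\mathcal{L}_{codebook} = \|\text{sg}[E(x)] - z_q\|_2^2 + \|E(x) - \text{sg}[z_q]\|_2^2$$

In the total loss it appears as:

```python
loss += self.codebook_weight * codebook_loss.mean()
```

It acts as a constraint that keeps the encoder output and the codebook consistent with each other.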

---

### 🎭 5. Adversarial loss

#### ➤ Generator (when updating the generator)
  ```python
  logits_fake = self.discriminator(reconstructions)
  g_loss = -torch.mean(logits_fake)
  ```
  * This is the **hinge-GAN generator loss**:
    $$\mathcal{L}_G = -\mathbb{E}[D(\hat{x})]$$
  * The generator wants the discriminator to score reconstructed images highly (to fool D).

#### ➤ Discriminator (when updating the discriminator)

  ```python
  logits_real = self.discriminator(inputs.detach())
  logits_fake = self.discriminator(reconstructions.detach())
  d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
  ```

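* This is the **hinge-GAN discriminator loss**:
  $$\mathcal{L}_D =
  \mathbb{E}[\max(0, 1 - D(x))] + \mathbb{E}[\max(0, 1 + D(\hat{x}))]$$

* With the vanilla GAN loss it becomes (here $D(\cdot)$ denotes the sigmoid of the discriminator logits):
  $$\mathcal{L}_D = -\mathbb{E}[\log D(x)] - \mathbb{E}[\log(1 - D(\hat{x}))]$$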
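
The formulas above translate into a few lines of PyTorch. The functions below are a sketch written from those formulas, with `_sketch` names to make clear they are not the repository's own `hinge_d_loss` / `vanilla_d_loss`, which may differ by a constant factor. The logits are made-up numbers.

```python
import torch
import torch.nn.functional as F


def hinge_d_loss_sketch(logits_real, logits_fake):
    """E[max(0, 1 - D(x))] + E[max(0, 1 + D(x_hat))] on raw discriminator logits."""
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return loss_real + loss_fake


def vanilla_d_loss_sketch(logits_real, logits_fake):
    """-E[log sigmoid(l_real)] - E[log(1 - sigmoid(l_fake))], written with softplus for stability."""
    return torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))


def hinge_g_loss_sketch(logits_fake):
    """Generator side: -E[D(x_hat)]."""
    return -torch.mean(logits_fake)


logits_real = torch.tensor([2.0, 0.5, -1.0])   # made-up discriminator outputs on real images
logits_fake = torch.tensor([-2.0, 0.0, 1.5])   # made-up discriminator outputs on reconstructions

d_hinge = hinge_d_loss_sketch(logits_real, logits_fake)
d_vanilla = vanilla_d_loss_sketch(logits_real, logits_fake)
g_hinge = hinge_g_loss_sketch(logits_fake)
print(float(d_hinge), float(d_vanilla), float(g_hinge))
```

Only samples on the wrong side of the margin contribute to the hinge loss: the real logit `2.0` and the fake logit `-2.0` add nothing here.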

---

### ⚖️ 6. Adaptive weight λ

```python
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer)
```

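This corresponds to Eq. (7):

$$\lambda = \frac{\|\nabla_{G_L}[\mathcal{L}_{rec}]\|}
{\|\nabla_{G_L}[\mathcal{L}_{GAN}]\| + \delta}$$

where $G_L$ is the last layer of the decoder and $\delta$ is a small constant for numerical stability.

Implementation details:

* `torch.autograd.grad` gives the gradient of each loss with respect to the last layer, and the norm of each gradient is taken;
* `+ 1e-4` in the denominator avoids division by zero;
* `torch.clamp()` bounds the resulting value;
* the result is finally multiplied by `self.discriminator_weight` (a global scaling factor).

**Interpretation:**

* If the GAN gradient is large (the adversarial signal is too strong), λ decreases;
* If the reconstruction gradient is large, λ increases;
* This keeps training stable and balances visual quality against structural fidelity.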
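
The steps listed above can be sketched as follows. This is an illustration rather than the repository's `calculate_adaptive_weight`: the function name, the toy "decoder last layer" and "critic" convolutions, and the clamp bounds `0.0` and `1e4` are all assumptions.

```python
import torch
import torch.nn as nn


def adaptive_weight_sketch(rec_loss, g_loss, last_layer_weight, scale=1.0, delta=1e-4):
    """Ratio of gradient norms at the last decoder layer, clamped and detached."""
    rec_grads = torch.autograd.grad(rec_loss, last_layer_weight, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer_weight, retain_graph=True)[0]
    weight = torch.norm(rec_grads) / (torch.norm(g_grads) + delta)
    weight = torch.clamp(weight, 0.0, 1e4).detach()   # assumed bounds; treated as a constant afterwards
    return weight * scale


torch.manual_seed(0)
last_layer = nn.Conv2d(4, 3, kernel_size=3, padding=1)   # toy stand-in for the decoder's final conv
critic = nn.Conv2d(3, 1, kernel_size=3, padding=1)       # toy stand-in for the discriminator
features = torch.randn(2, 4, 8, 8)
target = torch.randn(2, 3, 8, 8)

recon = last_layer(features)
rec_loss = torch.mean(torch.abs(target - recon))
g_loss = -torch.mean(critic(recon))

lam = adaptive_weight_sketch(rec_loss, g_loss, last_layer.weight)
lam_strong_gan = adaptive_weight_sketch(rec_loss, 10.0 * g_loss, last_layer.weight)
print(float(lam), float(lam_strong_gan))
```

Scaling the adversarial loss by 10 scales its gradient norm by 10, so the second weight comes out smaller, which is the balancing behaviour described in the bullets.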

---

### 🧠 7. Delayed discriminator start

```python
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
```

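Details:

* The adversarial term is switched off early in training (to avoid GAN instability while reconstructions are still poor);
* it is activated once `global_step` reaches `disc_start` (`global_step >= disc_start`);
* this amounts to a "warm-up" strategy.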
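
A step-based switch like this needs only a few lines. The function below is a sketch of the behaviour described above, not a copy of the repository's `adopt_weight`; the name `adopt_weight_sketch` and the start step of 10,000 are assumptions.

```python
def adopt_weight_sketch(weight, global_step, threshold=0, value=0.0):
    """Return `value` before `threshold` steps and `weight` from then on."""
    if global_step < threshold:
        return value
    return weight


disc_start = 10_000   # hypothetical step at which the discriminator is switched on
factors = {step: adopt_weight_sketch(1.0, step, threshold=disc_start)
           for step in (0, 9_999, 10_000, 50_000)}
for step, factor in factors.items():
    print(f"step {step:>6}: disc_factor = {factor}")
```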

---

### 🧾 8. Total loss (generator step)

$$\mathcal{L}_{total} =
\underbrace{\mathcal{L}_{rec/perceptual}}_{\text{reconstruction + perceptual}}+
\underbrace{d_{weight} \cdot disc_{factor} \cdot \mathcal{L}_{GAN}}_{\text{adversarial term}}
+\underbrace{\beta \cdot \mathcal{L}_{codebook}}_{\text{quantization constraint}}$$

---

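### 📘 9. Logged quantities

During training the model logs:

| Name          | Meaning                                  |
| ------------- | ---------------------------------------- |
| `total_loss`  | total loss                               |
| `quant_loss`  | codebook loss                            |
| `rec_loss`    | reconstruction loss (pixel + perceptual) |
| `p_loss`      | LPIPS perceptual part                    |
| `d_weight`    | current adaptive weight                  |
| `disc_factor` | discriminator on/off factor              |
| `g_loss`      | generator GAN loss                       |
| `disc_loss`   | discriminator loss                       |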

---

### ✅ 10. Summary: Mapping Code to the Paper

| Code                        | Paper                               | Role                               |
| --------------------------- | ----------------------------------- | ---------------------------------- |
| `nll_loss`                  | $\mathcal{L}_{rec}$                 | Reconstruction (pixel/perceptual)  |
| `codebook_loss`             | $\mathcal{L}_{VQ}$                  | Quantization consistency           |
| `g_loss`                    | $\mathcal{L}_{GAN}(E,G,Z,D)$        | Generator adversarial term         |
| `d_loss`                    | $\mathcal{L}_{D}$                   | Discriminator update               |
| `calculate_adaptive_weight` | Eq. (7)                             | Adaptive balancing weight λ        |
| `adopt_weight`              | Delayed discriminator start         | Training-stability trick           |


In [None]:
class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

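        # Eq. (7): ratio of gradient norms at the last decoder layer; 1e-4 guards against division by zero
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        # clamp to [0, 1e4] and detach so the weight is a constant during backprop
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        # global scaling factor (disc_weight)
        d_weight = d_weight * self.discriminator_weight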
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train"):
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
