导入环境

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

In [4]:
# 定义VAE
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, latent_size*2)
        )
        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.rand_like(std)
        return mu + eps*std
    
    def forward(self, x):
        # 编码
        mu_logvar = self.encoder(x)
        mu, logvar = torch.chunk(mu_logvar, 2, dim=1)
        # 重参数化
        z = self.reparameterize(mu, logvar)
        # 解码
        x_hat = self.decoder(z)
        return x_hat, mu, logvar
    
# 定义训练函数
def train_vae(model, train_loader, num_epochs, learning_rate):
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0

        for data in train_loader:
            images, _ = data
            images = images.view(images.size(0), -1) # 将图片展开成 784 维向量
            optimizer.zero_grad()  # 梯度清零

            # 前向传播
            outputs, mu, logvar = model(images)

            # 计算损失
            reconstruction_loss = criterion(outputs, images)
            kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            # 计算总损失
            total_loss = reconstruction_loss + kl_divergence

            # 反向传播
            total_loss.backward()
            optimizer.step()

            total_loss += total_loss.item()
        
        # 输出损失
        print("Epoch: {}, Loss: {:.4f}".format(epoch + 1, total_loss/len(train_loader)))

    print('Train finished')

Linear: 对与输入数据X执行操作$y = xA^T + b$其中A是模块的权重，b是模块的偏置，这两个值是在模块初始化中随机生成的

In [5]:
# 设置超参数
input_size = 784
hidden_size = 256
latent_size = 64
num_epochs = 10
learning_rate = 1e-3

# 加载MNISt数据集
from torchvision import datasets, transforms
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

# 创建模型
model = VAE(input_size, hidden_size, latent_size)

# 训练模型
train_vae(model, train_loader, num_epochs, learning_rate)

Epoch: 1, Loss: 0.0011
Epoch: 2, Loss: 0.0011
Epoch: 3, Loss: 0.0011
Epoch: 4, Loss: 0.0011
Epoch: 5, Loss: 0.0011
Epoch: 6, Loss: 0.0011
Epoch: 7, Loss: 0.0011
Epoch: 8, Loss: 0.0011
Epoch: 9, Loss: 0.0011
Epoch: 10, Loss: 0.0012
Train finished
