In [109]:
import os
import math
import matplotlib as plt
from tqdm import trange
from PIL import Image
from loguru import logger

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

In [104]:
train_batch_size = 32
test_batch_size = 32
epochs = 100
input_dim = 784
latent_dim = 64
output_dim = 784
lr = 1e-5
eval_num_images = 1
save_dir = "./generate/vae/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [114]:
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim, output_dim):
        super().__init__()
        # encoder
        self.encoder_mean = nn.Linear(input_dim, latent_dim)
        # 方差必须是非负数，但是MLP的输出有正有负，直接输出方差会导致计算KL损失时出现nan（ln(-1)）
        # 因此我们在取对数方差，输出ln(sigma)，ln可正可负，且保证了sigma为非负数
        # self.encoder_std = nn.Linear(input_dim, latent_dim)
        self.encoder_ln_std = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, output_dim)

    def forward(self, x):
        mean, ln_std = self.encoder_mean(x), self.encoder_ln_std(x)
        epsilon = torch.randn_like(mean)
        z = mean + torch.exp(ln_std) * epsilon
        x = self.decoder(z)
        return x, mean, ln_std

    def calculate_loss(self, x):
        x, mean, ln_std = self.forward(x)
        mse_loss = nn.MSELoss()(logit, feature)
        # kl = 0.5 * (mean ^ 2 + std ^ 2 - 2ln(std) - 1)
        kl_loss = torch.mean(0.5 * torch.sum(mean ** 2 + torch.exp(ln_std) ** 2 - 2 * ln_std - 1))
        loss = mse_loss + kl_loss
        return loss

In [106]:
# Dataset
transform = transforms.ToTensor()
fashion_mnist_train = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
fashion_mnist_test = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
print(f"Train: {len(fashion_mnist_train)}; Test: {len(fashion_mnist_test)})")

Train: 60000; Test: 10000)


In [107]:
# Loader
train_loader = DataLoader(fashion_mnist_train, batch_size=train_batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(fashion_mnist_test, batch_size=test_batch_size, shuffle=True, drop_last=True)

In [115]:
# Create VAE
model = VAE(input_dim, latent_dim, output_dim)
optim = torch.optim.AdamW(model.parameters(), lr=lr)
# Train
for epoch in trange(epochs):
    train_loss = 0.0
    test_loss = 0.0
    # train
    model.train()
    for idx, (feature, _) in enumerate(train_loader):
        optim.zero_grad()
        feature = feature.reshape(train_batch_size, -1)
        # loss
        loss = model.calculate_loss(feature)
        if torch.isnan(loss).any():
            logger.error("Loss is NaN.")
        loss.backward()
        optim.step()
        train_loss += loss.item()
    # 打印每个 epoch 的平均损失
    avg_loss = train_loss / len(train_loader.dataset)
    print(f"Epoch {epoch + 1}, Train Loss: {avg_loss:.4f}")
    # evaluation
    model.eval()
    with torch.no_grad():
        for idx, (feature, _) in enumerate(test_loader):
            feature = feature.reshape(test_batch_size, -1)
            loss = model.calculate_loss(feature)
            test_loss += loss.item()
        avg_test_loss = test_loss / len(test_loader.dataset)
        print(f"Epoch {epoch + 1}, Test Loss: {avg_test_loss: .4f}")
    # generate image
    with torch.no_grad():
        z = torch.randn(eval_num_images, latent_dim)
        images = model.decoder(z).cpu().reshape(-1, 1, 28, 28)
    # save image
    for i in range(eval_num_images):
        # 获取单张图像，去掉多余的维度
        img = images[i, 0].cpu().numpy()
        # 将 numpy 数组转换为 PIL 图像
        pil_img = Image.fromarray((img * 255).astype('uint8'))  # 转换为 [0, 255] 范围
        img_path = os.path.join(save_dir, f"{epoch}_{i + 1}.png")
        # 保存图像
        pil_img.save(img_path)


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1, Train Loss: 1.1007


 10%|█         | 1/10 [00:03<00:35,  3.94s/it]

Epoch 1, Test Loss:  0.5750
Epoch 2, Train Loss: 0.4626


 20%|██        | 2/10 [00:07<00:29,  3.66s/it]

Epoch 2, Test Loss:  0.3760
Epoch 3, Train Loss: 0.3191


 30%|███       | 3/10 [00:10<00:24,  3.54s/it]

Epoch 3, Test Loss:  0.2729
Epoch 4, Train Loss: 0.2386


 40%|████      | 4/10 [00:14<00:20,  3.48s/it]

Epoch 4, Test Loss:  0.2104
Epoch 5, Train Loss: 0.1872


 50%|█████     | 5/10 [00:17<00:17,  3.43s/it]

Epoch 5, Test Loss:  0.1684
Epoch 6, Train Loss: 0.1514


 60%|██████    | 6/10 [00:20<00:13,  3.44s/it]

Epoch 6, Test Loss:  0.1379
Epoch 7, Train Loss: 0.1249


 70%|███████   | 7/10 [00:24<00:10,  3.47s/it]

Epoch 7, Test Loss:  0.1150
Epoch 8, Train Loss: 0.1046


 80%|████████  | 8/10 [00:27<00:06,  3.46s/it]

Epoch 8, Test Loss:  0.0971
Epoch 9, Train Loss: 0.0886


 90%|█████████ | 9/10 [00:31<00:03,  3.45s/it]

Epoch 9, Test Loss:  0.0828
Epoch 10, Train Loss: 0.0757


100%|██████████| 10/10 [00:34<00:00,  3.46s/it]

Epoch 10, Test Loss:  0.0710



