## 初始化环境, 固定随机种子

In [1]:
import os
import pathlib

import torch
import torchvision as tv

import utils
from transformer_flow import Model


# Seed the RNGs through the project helper so runs are reproducible.
utils.set_random_seed(100)
# Root directory for everything this notebook writes (sample images, checkpoints).
# Documented with a comment: a bare string on the cell's last line would be echoed as output.
notebook_output_path = pathlib.Path('runs/notebook')

'输出路径'

## 设置超参数, 构建数据集、模型与优化器

In [2]:
# region Dataset metadata
dataset = 'mnist'      # dataset name (used in the output file names)
num_classes = 10       # number of class labels
img_size = 28          # images are img_size x img_size pixels
channel_size = 1       # number of image channels
# endregion

# region Training hyperparameters
# We use a small model for fast demonstration; increase the model size for better results.
patch_size = 4         # side length of the square image patches
channels = 128         # hidden-layer channel count
blocks = 4             # number of MetaBlocks
layers_per_block = 4   # number of layers in each MetaBlock
# Try different noise levels to see their effect.
noise_std = 0.1        # std of the Gaussian noise added to training images

batch_size = 256       # mini-batch size
lr = 2e-3              # base learning rate (also handed to the LR schedule)
# Increase epochs for better results.
epochs = 100           # number of full passes over the training set (not a step size)
sample_freq = 10       # sample and save images every `sample_freq` epochs
# endregion

# The patch grid used below ((img_size // patch_size) ** 2 patches) only covers the whole
# image when it tiles exactly; fail fast on an inconsistent configuration.
assert img_size % patch_size == 0, 'img_size must be divisible by patch_size'

# region Device selection
# Prefer CUDA, then Apple-silicon MPS, otherwise fall back to the CPU.
# The conditional expression short-circuits, so MPS is only probed when CUDA is absent.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f'using device {device}')
# endregion

# region Fixed sampling inputs
# Samples generated per class when checking generation quality. The sample grid is
# num_classes rows of this many images (the training cell saves grids with nrow=10).
samples_per_class = 10

# Fixed latent noise, reused at every sampling step so samples are comparable across epochs.
# Shape: (B, (img_size // patch_size) ** 2, channel_size * patch_size ** 2),
# where B = num_classes * samples_per_class: one row per sample, one entry per patch,
# each entry a flattened patch.
fixed_noise = torch.randn(
    num_classes * samples_per_class,
    (img_size // patch_size) ** 2,
    channel_size * patch_size ** 2,
    device=device,
)
# Matching class labels, grouped by class: [0]*samples_per_class, [1]*samples_per_class, ...
# Shape: (B,).
fixed_y = torch.arange(num_classes, device=device).repeat_interleave(samples_per_class)
# endregion


# Preprocessing: resize to img_size x img_size, convert to a [0, 1] tensor,
# then map to [-1, 1] via (x - 0.5) / 0.5.
transform = tv.transforms.Compose(
    [
        tv.transforms.Resize((img_size, img_size)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split (downloaded into the current directory if missing).
data = tv.datasets.MNIST('.', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # every optimizer step sees a full batch
)

# Class-conditional TarFlow model (constructor from transformer_flow), moved to the chosen device.
model = Model(
    in_channels=channel_size,
    img_size=img_size,
    patch_size=patch_size,
    channels=channels,
    num_blocks=blocks,
    layers_per_block=layers_per_block,
    num_classes=num_classes,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.95),
    weight_decay=1e-4,
)

# Per-step cosine LR schedule from utils. The positional args are presumably
# (warmup steps = one epoch, total steps, minimum lr, peak lr) -- TODO confirm against utils.
lr_schedule = utils.CosineLRSchedule(optimizer, len(data_loader), epochs * len(data_loader), 1e-6, lr)

# Output locations are keyed by the architecture and noise hyperparameters.
model_name = f'{patch_size}_{channels}_{blocks}_{layers_per_block}_{noise_std:.2f}'
sample_dir = notebook_output_path / f'{dataset}_samples_{model_name}'
ckpt_file = notebook_output_path / f'{dataset}_model_{model_name}.pth'
# parents=True also creates notebook_output_path, i.e. the checkpoint's directory.
sample_dir.mkdir(parents=True, exist_ok=True)

using device cuda


## 训练过程

In [3]:
# Put the model in training mode explicitly, so this cell behaves the same even when
# an earlier cell (e.g. evaluation) left it in eval mode.
model.train()

for epoch in range(epochs):

    # Running sum of per-batch losses; averaged over len(data_loader) when printed.
    losses = 0

    for x, y in data_loader:  # x.shape = (B, C, H, W), y.shape = (B)

        x: torch.Tensor = x.to(device)
        # Perturb the inputs with Gaussian noise of scale noise_std; the flow is fit to the noisy images.
        eps = noise_std * torch.randn_like(x)
        x = x + eps
        y: torch.Tensor = y.to(device)

        optimizer.zero_grad()

        # Forward pass (class-conditioned on y). Returns the final latent z, a list of
        # intermediate outputs, and the log-det terms consumed by get_loss.
        z, outputs, logdets = model(x, y)

        # Training objective; lower is better (the evaluation cell treats it as a negative log-likelihood).
        loss = model.get_loss(z, logdets)
        loss.backward()
        optimizer.step()
        # The schedule advances once per optimizer step, not once per epoch.
        lr_schedule.step()

        losses += loss.item()

    print(f"epoch {epoch} lr {optimizer.param_groups[0]['lr']:.6f} loss {losses / len(data_loader):.4f}")
    # Mean squared value of each intermediate output of the last batch (quick health check).
    # `h` is used instead of `z` so the final latent `z` is not visually shadowed.
    print('layer norms', ' '.join(f'{h.pow(2).mean():.4f}' for h in outputs))

    # Periodically run the flow in reverse on the fixed noise to inspect sample quality.
    if (epoch + 1) % sample_freq == 0:
        # Sample in eval mode (matters if the model has dropout-like layers), then resume training.
        model.eval()
        with torch.no_grad():
            samples = model.reverse(fixed_noise, fixed_y)
        model.train()
        tv.utils.save_image(samples, sample_dir / f'samples_{epoch:03d}.png', normalize=True, nrow=10)
        # Also visualize the latents of the last training batch (first 100) as images.
        tv.utils.save_image(model.unpatchify(z[:100]), sample_dir / f'latent_{epoch:03d}.png', normalize=True, nrow=10)
        print('sampling complete')

epoch 0 lr 0.002000 loss -0.9788
layer norms 2.3528 2.1979 1.4153 0.9509
epoch 1 lr 0.001999 loss -1.3905
layer norms 1.7008 1.8958 1.3912 0.9854
epoch 2 lr 0.001998 loss -1.4862
layer norms 1.4503 1.5879 1.1261 0.9480
epoch 3 lr 0.001995 loss -1.5293
layer norms 1.3733 1.4909 1.0987 0.9519
epoch 4 lr 0.001992 loss -1.5488
layer norms 1.4172 1.5663 1.1787 1.0115
epoch 5 lr 0.001987 loss -1.5611
layer norms 1.4108 1.5871 1.2571 0.9760
epoch 6 lr 0.001982 loss -1.5695
layer norms 1.4366 1.6714 1.3719 1.0523
epoch 7 lr 0.001975 loss -1.5720
layer norms 1.4538 1.7026 1.3861 1.0639
epoch 8 lr 0.001968 loss -1.5795
layer norms 1.4884 1.8108 1.5291 0.9818
epoch 9 lr 0.001960 loss -1.5806
layer norms 1.5015 1.8684 1.6006 0.9805
sampling complete
epoch 10 lr 0.001950 loss -1.5774
layer norms 1.6227 2.0732 1.7008 0.9707
epoch 11 lr 0.001940 loss -1.5878
layer norms 1.5969 2.0625 1.7733 0.9937
epoch 12 lr 0.001928 loss -1.5892
layer norms 1.6248 2.1684 1.9259 1.0145
epoch 13 lr 0.001916 loss -1.5

## 保存模型

In [4]:
# Persist only the parameters/buffers (state_dict), not the pickled Module object;
# reload with model.load_state_dict(torch.load(ckpt_file)).
torch.save(obj=model.state_dict(), f=ckpt_file)

## 评估模型: 基于贝叶斯规则的分类准确率

In [5]:
# Evaluate the class-conditional flow as a classifier via Bayes' rule, p(y|x) = p(y) p(x|y) / p(x).
# p(y) is ignored (treated as uniform), so the prediction is the class whose conditional
# per-sample loss (negative log-likelihood up to constants) is smallest.
#
# Separate names keep the training `data` / `data_loader` globals intact: reusing them here
# meant a later re-run of the training cell would silently train on the test split.
test_data = tv.datasets.MNIST('.', transform=transform, train=False, download=False)
# Order does not affect accuracy, so no shuffling is needed.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

# Evaluate in eval mode, remembering the previous mode so it can be restored afterwards.
was_training = model.training
model.eval()

num_correct = 0
num_examples = 0
with torch.no_grad():
    for x, y in test_loader:  # x.shape = (B, C, H, W), y.shape = (B)
        x = x.to(device)
        y = y.to(device)
        batch = y.size(0)

        # Score every image under every class: the batch is tiled class-major, so row block k
        # (of size B) pairs all images with label k.
        x_all = x.repeat(num_classes, 1, 1, 1)  # (num_classes * B, C, H, W)
        y_all = torch.arange(num_classes, device=device).repeat_interleave(batch)  # (num_classes * B,)

        # Test images are scored clean (no noise_std perturbation).
        z, _, logdets = model(x_all, y_all)
        # Per-sample version of the training loss; assumes logdets is per-sample, shape
        # (num_classes * B,) -- TODO confirm against Model.forward.
        class_nll = 0.5 * z.pow(2).mean(dim=[1, 2]) - logdets
        pred = class_nll.reshape(num_classes, batch).argmin(dim=0)

        num_correct += (pred == y).sum().item()
        num_examples += batch

model.train(was_training)

print(f'Accuracy %{100 * num_correct / num_examples:.2f}')
torch.cuda.empty_cache()

Accuracy %98.36
