In [3]:
from UNet import MaskedAE
import torch
import torch.nn.functional as F

In [4]:
# 初始化
model = MaskedAE(mask_type="block")
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# 模拟数据加载
train_loader = torch.utils.data.DataLoader(
    torch.randn(1000, 22, 64),  # 模拟1000个样本
    batch_size=8,
    shuffle=True
)

In [5]:
# 训练循环
for epoch in range(100):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # 动态调整掩码比例 (20% -> 60%)
        curr_mask_ratio = 0.2 + 0.4 * (epoch / 100)
        outputs = model(batch, mask_ratio=curr_mask_ratio)
        loss = outputs['loss']
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    # 验证重建效果
    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            test_sample = next(iter(train_loader))
            out = model(test_sample)
            # 计算可见区域的PSNR
            mse = F.mse_loss(out['reconstructed'], test_sample)
            psnr = -10 * torch.log10(mse)
            print(f"Epoch {epoch} | Loss: {total_loss/len(train_loader):.4f} | "
                  f"PSNR: {psnr:.2f} dB")

Epoch 0 | Loss: 0.8209 | PSNR: 0.06 dB
Epoch 10 | Loss: 0.6134 | PSNR: 0.69 dB
Epoch 20 | Loss: 0.5645 | PSNR: 0.78 dB
Epoch 30 | Loss: 0.5149 | PSNR: 0.74 dB
Epoch 40 | Loss: 0.4712 | PSNR: 0.84 dB
Epoch 50 | Loss: 0.4441 | PSNR: 0.94 dB
Epoch 60 | Loss: 0.4061 | PSNR: 0.96 dB
Epoch 70 | Loss: 0.3802 | PSNR: 1.05 dB
Epoch 80 | Loss: 0.3457 | PSNR: 0.91 dB
Epoch 90 | Loss: 0.3225 | PSNR: 0.85 dB
