In [1]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset 
import os
import numpy as np
from PIL import Image

os.chdir("D:\\code\\VAE")
class StressDataset(Dataset):
    def __init__(self, stress_dir, transform=None):
        self.stress_dir = stress_dir
        self.stress_files = [os.path.join(stress_dir, f) for f in os.listdir(stress_dir) if f.endswith('.png')]
        self.transform = transform

    def __len__(self):
        return len(self.stress_files)

    def __getitem__(self, idx):
        img_path = self.stress_files[idx]
        stress = Image.open(img_path).convert('RGB')
        stress1 = np.array(stress).astype(np.float32).transpose(2, 0, 1)/255.0
        
        if self.transform:
            stress = self.transform(stress)
        return stress,stress1

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

traindataset = StressDataset(
    stress_dir=r"D:\\cosegGuitar\\256x256\\stress",
    transform=transform
)

dataloader = DataLoader(traindataset, batch_size=8, shuffle=True, num_workers=0)

In [2]:
import torch.nn.functional as F
def loss_function(recon_x, x, mu, log_var):
    BLC = F.binary_cross_entropy(recon_x, x, reduction='mean')
    KLD = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
    return BLC + 0.001 * KLD

In [3]:
from model import VAE
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 初始化模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
model.load_state_dict(torch.load('vae_model_state_dict.pth'))

# 定义优化器和学习率调度器
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=1, gamma=0.95)
print(f"initial lr: {scheduler.get_last_lr()[0]}")

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (stress, stress1) in enumerate(dataloader):
        stress = stress.to(device)
        stress1 = stress1.to(device)
        optimizer.zero_grad()
        
        # 前向传播
        recon_stress, mu, log_var = model(stress)

        # 计算损失
        loss = loss_function(recon_stress, stress1, mu, log_var)
     
        # 反向传播与优化
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
       
        if batch_idx % 10 == 0:
            print(f" batchidx [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}")
            if loss.item() < 0.03:
                torch.save(model.state_dict(), 'vae_model_state_dict.pth')
                print('model saved!')
                break
    scheduler.step()
    print("recon_x range:", recon_stress.min().item(), recon_stress.max().item())
    print("x range:", stress1.min().item(), stress1.max().item())
    print("log_var range:", log_var.min().item(), log_var.max().item())
    print(f"== Epoch [{epoch+1}/{num_epochs}], avrage Loss: {train_loss / len(dataloader):.4f} ==")
    if train_loss / len(dataloader) < 0.03:
       torch.save(model.state_dict(), 'vae_model_state_dict.pth')
       print('model saved!')

initial lr: 0.001
 batchidx [0/1149], Loss: 0.0357
 batchidx [10/1149], Loss: 0.0380
 batchidx [20/1149], Loss: 0.0431
 batchidx [30/1149], Loss: 0.0392
 batchidx [40/1149], Loss: 0.0326
 batchidx [50/1149], Loss: 0.0358
 batchidx [60/1149], Loss: 0.0291
model saved!
recon_x range: 2.780578789440824e-08 1.0
x range: 0.0 1.0
log_var range: 6.203053635545075e-05 0.016567014157772064
== Epoch [1/10], avrage Loss: 0.0019 ==
model saved!


KeyboardInterrupt: 