In [None]:
import numpy as np
import wandb
import yaml
import shutil
from yaml.loader import SafeLoader
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as td
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, datasets
from sklearn.model_selection import train_test_split
# from netAE import AE
# from netVAE import VAE

In [None]:
# Path to the (optional) experiment configuration file.
path_yaml = "config.yaml"

# Config loading is currently disabled. If re-enabled, note that `yaml.load`
# expects a stream or YAML text, not a file path — passing `path_yaml` directly
# would just parse the literal string "config.yaml". Open the file first:
# with open(path_yaml) as f:
#     config = yaml.load(f, Loader=SafeLoader)

In [None]:
# Load the raw dataset (the filename suggests a 4-D array — TODO confirm layout).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; keep it only if the file needs it and is trusted.
data = np.load(file="data4D.npy", allow_pickle=True)
data.shape

In [None]:
# Train / test split (80/20, fixed random_state for reproducibility).
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Convert the NumPy arrays to float32 tensors once here, rather than per batch.
# NOTE(review): torch.from_numpy rejects object arrays — if data4D.npy really
# needs allow_pickle=True, convert it to a numeric dtype before this point.
train_data_tensor = torch.from_numpy(train_data).float()
test_data_tensor = torch.from_numpy(test_data).float()
# The VAE's nn.Conv2d layers expect batches shaped (N, C, H, W).
assert train_data_tensor.ndim == 4, (
    f"expected 4-D data (N, C, H, W), got shape {tuple(train_data_tensor.shape)}"
)

# Build the DataLoaders. A seeded generator makes the training shuffle order
# reproducible across Restart & Run All.
batch_size = 128
train_dataset = TensorDataset(train_data_tensor)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

test_dataset = TensorDataset(test_data_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder with a diagonal-Gaussian posterior.

    The encoder applies three stride-2 convolutions (each halving H and W), so
    inputs must be square images whose side ``img_size`` is divisible by 8. The
    decoder mirrors the encoder with transposed convolutions and returns an
    unbounded reconstruction (no output activation).

    Args:
        img_channels: number of input/output image channels.
        latent_dim: dimensionality of the latent code z.
        img_size: side length of the square input images. Defaults to 32, the
            size implied by the original hard-coded ``16 * last_dim``.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 8.
    """

    def __init__(self, img_channels, latent_dim, img_size=32):
        super(VAE, self).__init__()
        if img_size <= 0 or img_size % 8 != 0:
            raise ValueError(f"img_size must be a positive multiple of 8, got {img_size}")
        last_dim = 128
        # Spatial side after three stride-2 convolutions (32 -> 4 by default).
        feat_size = img_size // 8
        flat_dim = last_dim * feat_size * feat_size

        # Encoder: (N, C, S, S) -> (N, flat_dim)
        self.encoder = nn.Sequential(
            nn.Conv2d(img_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, last_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Heads producing the posterior mean and log-variance.
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        # Decoder: (N, latent_dim) -> (N, C, S, S). No output activation, so
        # reconstructions are unbounded.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_dim),
            nn.Unflatten(1, (last_dim, feat_size, feat_size)),
            nn.ConvTranspose2d(last_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),
        )

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Not used by ``forward`` (which samples via ``q_z.rsample()``); kept for
        API compatibility.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample z with the reparameterization trick, and decode.

        Returns:
            (x_recon, q_z): the reconstruction (same shape as ``x``) and the
            posterior ``td.Normal(mu, std)``, from which callers compute the KL term.
            z is sampled in both train and eval mode.
        """
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # exp(0.5 * logvar) == sqrt(exp(logvar)), but does not overflow until
        # logvar is twice as large.
        std = torch.exp(0.5 * logvar)
        q_z = td.Normal(mu, std)
        # rsample keeps z differentiable w.r.t. mu and std.
        z = q_z.rsample()
        x_recon = self.decoder(z)
        return x_recon, q_z

In [None]:
def vae_loss(x, x_recon, q_z):
    """Negative ELBO, summed over the whole batch.

    Args:
        x: target batch.
        x_recon: reconstruction with the same shape as ``x``.
        q_z: approximate posterior (a ``td.Normal``) as returned by ``VAE.forward``.

    Returns:
        Scalar tensor: summed squared reconstruction error plus
        KL(q_z || N(0, I)) summed over all latent elements.
    """
    # Standard-normal prior matching the posterior's shape, dtype and device.
    prior = td.Normal(torch.zeros_like(q_z.mean), torch.ones_like(q_z.stddev))
    squared_error = nn.functional.mse_loss(x_recon, x, reduction='sum')
    kl_term = td.kl_divergence(q_z, prior).sum()
    return squared_error + kl_term

In [None]:
# Model hyper-parameters. img_channels must match the channel axis of the data
# (presumably axis 1 of data4D — TODO confirm).
img_channels = 3
latent_dim = 128
learning_rate = 1e-3

# Seed weight initialisation so Restart & Run All reproduces the same model.
torch.manual_seed(42)
model = VAE(img_channels, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
num_epochs = 50

# TRAIN
# The batch variable is `x_batch`, not `data`: reusing `data` would overwrite the
# full dataset loaded earlier, so re-running the split cell after training would
# silently split the last mini-batch instead of the dataset.
model.train()
for epoch in range(num_epochs):
    train_loss = []
    for (x_batch,) in train_loader:
        x_batch = x_batch.float()
        optimizer.zero_grad()
        recon_batch, q_z = model(x_batch)
        loss = vae_loss(x_batch, recon_batch, q_z)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

    # vae_loss sums over the batch, so divide by the number of samples (not
    # batches) to report a per-sample average.
    print(f'====> Epoch: {epoch + 1:02d} | Average loss: {sum(train_loss) / len(train_loader.dataset):.4f}')

In [None]:
# EVALUATE
# forward() still samples z in eval mode, so this is a single-sample
# (stochastic) estimate of the negative ELBO on the test set.
model.eval()
test_loss = 0.0

with torch.no_grad():
    # Use `x_batch` so the global `data` (the full dataset) is not overwritten.
    for (x_batch,) in test_loader:
        x_batch = x_batch.float()
        recon_batch, q_z = model(x_batch)
        test_loss += vae_loss(x_batch, recon_batch, q_z).item()

# vae_loss sums over the batch, so this is the average loss per test sample.
test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))