In [20]:
import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as td
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, datasets
from sklearn.model_selection import train_test_split
from netVAE import VAE

In [3]:
data = np.load("data4D.npy", allow_pickle= True)
data.shape

(5000, 3, 32, 32)

In [4]:
# Trainset/testset
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# 将NumPy数组转换为张量
train_data_tensor = torch.from_numpy(train_data)
test_data_tensor = torch.from_numpy(test_data)

# 创建DataLoader
batch_size = 128
train_dataset = TensorDataset(train_data_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(test_data_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
class VAE(nn.Module):
    def __init__(self, img_channels, latent_dim):
        super(VAE, self).__init__()

        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(img_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # 获取编码器的输出大小
        self.encoder_output_size = self.encoder(torch.zeros(1, img_channels, 32, 32)).shape[1]

        # 计算均值和对数方差
        self.fc_mu = nn.Linear(self.encoder_output_size, latent_dim)
        self.fc_logvar = nn.Linear(self.encoder_output_size, latent_dim)

        # 解码器
        self.fc_decode = nn.Linear(latent_dim, self.encoder_output_size)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(self.fc_decode(z))
        return x_recon, mu, logvar

In [6]:
def vae_loss(x, x_recon, mu, logvar):
    recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum') / x.size(0)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl_div

In [7]:
img_channels = 3
latent_dim = 128
model = VAE(img_channels, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
num_epochs = 50

# TRAIN
model.train()

for epoch in range(num_epochs):
    train_loss = 0.0
    for batch_idx, (data,) in enumerate(train_loader):
        data = data.float()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = vae_loss(data, recon_batch, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    train_loss /= len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch + 1, train_loss))

====> Epoch: 1 Average loss: 2.9654
====> Epoch: 2 Average loss: 1.3382
====> Epoch: 3 Average loss: 1.3306
====> Epoch: 4 Average loss: 1.3334
====> Epoch: 5 Average loss: 1.3313
====> Epoch: 6 Average loss: 1.3217
====> Epoch: 7 Average loss: 1.3270
====> Epoch: 8 Average loss: 1.3214
====> Epoch: 9 Average loss: 1.3255
====> Epoch: 10 Average loss: 1.3325
====> Epoch: 11 Average loss: 1.3190
====> Epoch: 12 Average loss: 1.3236
====> Epoch: 13 Average loss: 1.3178
====> Epoch: 14 Average loss: 1.3159
====> Epoch: 15 Average loss: 1.3084
====> Epoch: 16 Average loss: 1.3095
====> Epoch: 17 Average loss: 1.3016
====> Epoch: 18 Average loss: 1.3032
====> Epoch: 19 Average loss: 1.2965
====> Epoch: 20 Average loss: 1.2933
====> Epoch: 21 Average loss: 1.2921
====> Epoch: 22 Average loss: 1.2869
====> Epoch: 23 Average loss: 1.2892
====> Epoch: 24 Average loss: 1.2942
====> Epoch: 25 Average loss: 1.2920
====> Epoch: 26 Average loss: 1.2916
====> Epoch: 27 Average loss: 1.2897
====> Epoc

In [22]:
model.eval()
test_loss = 0.0

with torch.no_grad():
    for i, (data,) in enumerate(test_loader):
        data = data.float()
        recon_batch, mu, logvar = model(data)
        test_loss += vae_loss(data, recon_batch, mu, logvar).item()

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))

====> Test set loss: 1.2608


In [17]:
reconstructions = []
originals = []
model.eval()
test_loss = 0.0
with torch.no_grad():
    for i, (data,) in enumerate(test_loader):
        data = data.float()
        originals.append(data.numpy())
        recon_batch, mu, logvar = model(data)
        reconstructions.append(recon_batch.numpy())
        test_loss += vae_loss(data, recon_batch, mu, logvar).item()

test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))

====> Test set loss: 1.2609


In [19]:
len(reconstructions)

8

In [12]:
reconstructions = np.concatenate(reconstructions, axis=0)
originals = np.concatenate(originals, axis=0)

np.save('reconstructions.npy', reconstructions)
np.save('originals.npy', originals)

In [13]:
reconstructions.shape

(3000, 32, 32)

In [15]:
originals.shape

(3000, 32, 32)

In [16]:
test_data.shape

(1000, 3, 32, 32)

In [21]:
reconstructions = []
originals = []

with torch.no_grad():
    for i, (data,) in enumerate(test_loader):
        data = data.float()
        originals.append(data.cpu().numpy())
        recon_batch, mu, logvar = model(data)
        reconstructions.append(recon_batch.cpu().numpy())
        test_loss += vae_loss(data, recon_batch, mu, logvar).item()
test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))

====> Test set loss: 1.2621


In [23]:
retest = np.concatenate(reconstructions, axis=0)
test = np.concatenate(originals, axis=0)

np.save('retest.npy', retest)
np.save('test.npy', test)

In [26]:
torch.save(model.state_dict(), 'vae_model.pth')
