In [5]:
import torch.nn as nn
import torch
import torch.functional as F

class Encoder1(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(1000, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        self.softplus = nn.Softplus()

    def forward(self, x):
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

class Decoder1(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 1000)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = self.sigmoid(self.fc21(hidden))
        return loc_img


In [6]:
class Encoder2(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        self.softplus = nn.Softplus()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

class Decoder2(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 1000)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = self.sigmoid(self.fc21(hidden))
        return loc_img


In [10]:
class VAE(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.encoder1 = Encoder1(z_dim, hidden_dim)
        self.decoder1 = Decoder1(z_dim, hidden_dim)
        self.encoder2 = Encoder2(z_dim, hidden_dim)
        self.decoder2 = Decoder2(z_dim, hidden_dim)

    def reparameterize(self, z_loc, z_scale):
        epsilon = torch.randn_like(z_loc)
        z = z_loc + epsilon * z_scale
        return z

    def forward(self, x):
        z_loc1, z_scale1 = self.encoder1(x)
        z1 = self.reparameterize(z_loc1, z_scale1)
        x_recon1 = self.decoder1(z1)

        z_loc2, z_scale2 = self.encoder2(x_recon1)
        z2 = self.reparameterize(z_loc2, z_scale2)
        x_recon2 = self.decoder2(z2)

        return x_recon2, z_loc1, z_scale1, z_loc2, z_scale2

    def compute_loss(self, x, x_recon2, z_loc1, z_scale1, z_loc2, z_scale2):
        # Reconstruction loss
        recon_loss = F.binary_cross_entropy(x_recon2, x.view(1, 1000), reduction='sum')

        # KL divergence loss
        kl_loss1 = -0.5 * torch.sum(1 + torch.log(z_scale1.pow(2)) - z_loc1.pow(2) - z_scale1.pow(2))
        kl_loss2 = -0.5 * torch.sum(1 + torch.log(z_scale2.pow(2)) - z_loc2.pow(2) - z_scale2.pow(2))

        # Total loss
        loss = recon_loss + kl_loss1 + kl_loss2

        return loss


In [12]:
# 定义超参数
z_dim = 20
hidden_dim = 256
# 随机生成数据
data = torch.randn(1000, 1)
data = data.reshape((1, 1000))


# 实例化VAE模型
vae = VAE(z_dim, hidden_dim)

# 定义优化器
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)

num_epochs = 2000
# 训练循环
for epoch in range(num_epochs):
    optimizer.zero_grad()
    x_recon2, z_loc1, z_scale1, z_loc2, z_scale2 = vae.forward(data)
    loss = vae.compute_loss(data, x_recon2, z_loc1, z_scale1, z_loc2, z_scale2)
    loss.backward()
    optimizer.step()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x1000 and 20x256)