In [1]:
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

In [2]:
batch_size = 128
seed = 1
epochs = 200
# Use the GPU only when one is actually present; a hard-coded True makes
# `.to(device)` fail on CPU-only machines.
cuda = torch.cuda.is_available()
log_interval = 10  # presumably batches between training log lines — confirm in the training loop
h_d = 512  # width of the large hidden layer in the encoder/decoder MLPs
l_d = 32   # per-frame latent size (also the transition RNNs' hidden size)
u_d = 1    # per-frame size of the auxiliary input `u` fed to the transition RNNs

# Seed every RNG the notebook uses so re-runs are reproducible
# (torch.manual_seed also seeds all CUDA devices).
torch.manual_seed(seed)
np.random.seed(seed);

<torch._C.Generator at 0x14955fc6c10>

In [5]:
class VAE(nn.Module):
    """Sequence VAE with an RNN transition over per-frame latents.

    Pipeline (see ``forward``): each flattened frame is encoded by an MLP into
    a Gaussian (mu1, logvar1) and sampled to z1; the z1 frames are regrouped
    into sequences of ``seq_len``, concatenated with the auxiliary input ``u``
    and passed through two RNNs producing (mu2, logvar2); z2 is sampled from
    those and decoded back to frame space through a sigmoid.

    Args:
        hidden_dim: width of the large MLP hidden layer; defaults to the
            notebook config ``h_d``.
        latent_dim: per-frame latent size and RNN hidden size; defaults to
            the notebook config ``l_d``.
        control_dim: per-frame size of ``u``; defaults to the notebook
            config ``u_d``.
        seq_len: number of consecutive frames grouped into one sequence.
        input_dim: number of values in one flattened frame.
        bottleneck_dim: width of the intermediate MLP layer.
    """

    def __init__(self, hidden_dim=None, latent_dim=None, control_dim=None,
                 seq_len=5, input_dim=784, bottleneck_dim=128):
        super(VAE, self).__init__()
        # None means "use the notebook-level config", so VAE() behaves as before.
        hidden_dim = h_d if hidden_dim is None else hidden_dim
        latent_dim = l_d if latent_dim is None else latent_dim
        control_dim = u_d if control_dim is None else control_dim

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.control_dim = control_dim
        self.seq_len = seq_len

        # Encoder: frame -> hidden -> bottleneck -> (mu1, logvar1).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, bottleneck_dim)
        self.fc21 = nn.Linear(bottleneck_dim, latent_dim)
        self.fc22 = nn.Linear(bottleneck_dim, latent_dim)

        # Transition layers: RNNs over [z1, u] producing mean and log-variance of z2.
        rnn_input_dim = latent_dim + control_dim
        self.rnn_mu = nn.RNN(input_size=rnn_input_dim, hidden_size=latent_dim, batch_first=True)
        self.rnn_sigma = nn.RNN(input_size=rnn_input_dim, hidden_size=latent_dim, batch_first=True)

        # Decoder: latent -> bottleneck -> hidden -> frame.
        self.fc3 = nn.Linear(latent_dim, bottleneck_dim)
        self.fc4 = nn.Linear(bottleneck_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened frames (..., input_dim) to ``(mu1, logvar1)``.

        Named ``encode`` because ``forward`` calls it; it must not share a
        name with ``decode``, or the later definition shadows it.
        """
        x = x.float()
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc21(h2), self.fc22(h2)

    def reparameterize1(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def transition(self, z, u):
        """Run the transition RNNs over ``cat([z, u], dim=2)``.

        ``forward`` passes z as (batch, seq_len, latent_dim) and u as
        (batch, seq_len, control_dim). Returns ``(mu2, logvar2)``, each the
        full RNN output sequence (batch, seq_len, latent_dim).

        NOTE(review): nn.RNN defaults to a tanh nonlinearity, so mu2 and
        logvar2 are confined to (-1, 1); confirm that is intended.
        """
        z = z.float()
        u = u.float()
        rnn_input = torch.cat((z, u), dim=2)
        mu2, _ = self.rnn_mu(rnn_input)
        logvar2, _ = self.rnn_sigma(rnn_input)
        return mu2, logvar2

    def reparameterize2(self, mu, logvar):
        """Sample z2 from the transition Gaussian; identical to ``reparameterize1``."""
        return self.reparameterize1(mu, logvar)

    def decode(self, z):
        """Map latents (..., latent_dim) to frames (..., input_dim) in (0, 1)."""
        z = z.float()
        h3 = F.relu(self.fc3(z))
        h4 = F.relu(self.fc4(h3))
        return torch.sigmoid(self.fc5(h4))

    def forward(self, x, u):
        """Encode frames, run the latent transition, and decode.

        Args:
            x: frames reshapeable to (-1, input_dim); the number of frames
                must be a multiple of ``seq_len``.
            u: auxiliary input reshapeable to (-1, seq_len, control_dim).

        Returns:
            ``(recon, mu2, logvar2)``: recon is (batch, seq_len, input_dim);
            mu2 / logvar2 are (batch, seq_len, latent_dim).

        NOTE(review): the encoder's (mu1, logvar1) are not returned, so any
        loss built from these outputs has no KL term on the encoder.
        """
        mu1, logvar1 = self.encode(x.reshape(-1, self.input_dim))
        z1 = self.reparameterize1(mu1, logvar1)
        # Regroup per-frame latents into sequences for the batch_first RNNs.
        z1 = z1.reshape(-1, self.seq_len, self.latent_dim)
        u = u.float().reshape(-1, self.seq_len, self.control_dim)
        mu2, logvar2 = self.transition(z1, u)
        z2 = self.reparameterize2(mu2, logvar2)
        return self.decode(z2), mu2, logvar2

In [7]:
# Fall back to CPU when CUDA is requested but no GPU is available, instead of
# failing inside `.to(device)`.
device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")

model = VAE().to(device)
# Adam optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
def loss_function(recon_x, x, mu, logvar, input_dim=784):
    """Summed binary cross-entropy reconstruction loss plus KL to N(0, I).

    Args:
        recon_x: predictions in [0, 1] with ``input_dim`` values per frame,
            e.g. (batch, seq_len, input_dim) as returned by ``VAE.forward``.
        x: target frames, reshapeable to (-1, input_dim).
        mu, logvar: Gaussian parameters; their KL divergence to a standard
            normal is summed over every element.
        input_dim: number of values per frame (784 by default).

    Returns:
        Scalar tensor BCE + KLD, both summed (not averaged).
    """
    # Flatten both sides to (frames, input_dim): the model's 3-D output would
    # otherwise not match the 2-D target, which binary_cross_entropy rejects.
    recon_flat = recon_x.reshape(-1, input_dim)
    target = x.reshape(-1, input_dim).to(recon_flat.dtype)
    BCE = F.binary_cross_entropy(recon_flat, target, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over all elements.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

## New loss for the RNN model