In [64]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from torch import distributions
import os
from torchvision.utils import save_image

In [65]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     # Normalize the images to be -0.5, 0.5
     transforms.Normalize(0.5, 1)]
    )

In [66]:
mnist = torchvision.datasets.MNIST('data/', download=True, transform=transform)

In [82]:
input_dim = 28 * 28
batch_size = 128
num_epochs = 100
learning_rate = 0.001
hidden_size = 512
latent_size = 8

In [68]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [69]:
dataloader = torch.utils.data.DataLoader(
    mnist, batch_size=batch_size,
    shuffle=True, 
    pin_memory=torch.cuda.is_available())

In [70]:

print('Number of samples: ', len(mnist))

Number of samples:  60000


In [72]:
encoder = Encoder(input_dim, hidden_size, latent_size)
decoder = Decoder(latent_size, hidden_size, input_dim)

vae = VAE(encoder, decoder).to(device)

In [73]:
print(vae)

VAE(
  (encoder): Encoder(
    (linear1): Linear(in_features=784, out_features=512, bias=True)
    (linear2): Linear(in_features=512, out_features=512, bias=True)
    (enc_mu): Linear(in_features=512, out_features=8, bias=True)
    (enc_log_sigma): Linear(in_features=512, out_features=8, bias=True)
  )
  (decoder): Decoder(
    (linear1): Linear(in_features=8, out_features=512, bias=True)
    (linear2): Linear(in_features=512, out_features=784, bias=True)
  )
)


In [74]:
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

In [75]:
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [76]:
class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, latent_size):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.enc_mu = torch.nn.Linear(H, latent_size)
        self.enc_log_sigma = torch.nn.Linear(H, latent_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        mu = self.enc_mu(x)
        log_sigma = self.enc_log_sigma(x)
        sigma = torch.exp(log_sigma)
        return torch.distributions.Normal(loc=mu, scale=sigma)


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        

    def forward(self, x):
        x = F.relu(self.linear1(x))
        # mu = torch.tanh(self.linear2(x))
        return F.sigmoid(self.linear2(x))

class VAE(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, state):
        q_z = self.encoder(state)
        z = q_z.rsample()
        return self.decoder(z), q_z

In [83]:
for epoch in range(num_epochs):
    for data in dataloader:
        inputs, _ = data
        inputs = inputs.view(-1, input_dim).to(device)
        optimizer.zero_grad()
        p_x, q_z = vae(inputs)
        # log_prob(value)是计算value在定义的正态分布（mean,1）中对应的概率的对数，
        # 取负对数作为loss，所以对应的概率越大则loss越小，优化降低loss也就是让x对应的概率密度加大
        # log_likelihood = p_x.log_prob(inputs).sum(-1).mean() 好像不太好用
        log_likelihood  = F.binary_cross_entropy(p_x, inputs, reduction="sum")
        kl = - 0.5 * torch.sum(1 + q_z.variance - q_z.mean.pow(2) - q_z.variance.exp())

        loss = log_likelihood + kl
        loss.backward()
        optimizer.step()
        l = loss.item()
    print(epoch, l, log_likelihood.item(), kl.item())

0 -2732636.25 -2733764.25 1127.952880859375
1 -2763952.75 -2764994.5 1041.76220703125
2 -2834826.5 -2836146.75 1320.203857421875
3 -2746087.75 -2747054.0 966.3382568359375
4 -2791350.25 -2792410.5 1060.2666015625
5 -2762260.0 -2763240.5 980.5943603515625
6 -2781321.25 -2782207.0 885.6624755859375
7 -2761483.75 -2762442.5 958.6341552734375
8 -2861044.75 -2861969.0 924.1393432617188
9 -2826138.5 -2827117.75 979.205810546875
10 -2819701.5 -2820668.5 966.9071655273438
11 -2844188.75 -2845157.25 968.5294799804688
12 -2844634.75 -2845560.25 925.5323486328125
13 -2822638.25 -2823618.75 980.515625
14 -2873436.75 -2874352.75 916.0518798828125
15 -2857629.5 -2858460.0 830.5635986328125
16 -2875819.0 -2876661.75 842.7885131835938
17 -2829526.25 -2830415.5 889.1818237304688
18 -2887804.5 -2888700.0 895.4381103515625
19 -2840850.75 -2841713.75 863.0498046875
20 -2871963.25 -2872825.0 861.8466796875
21 -2879090.0 -2879868.0 777.9322509765625
22 -2899586.5 -2900390.5 803.8920288085938
23 -2854073.25 

In [85]:
with torch.no_grad():
    # 保存示例图片
    # 生成一个随机张量
    z = torch.randn(16, 784).to(device)
    # 将模型结果重新调整成批量，通道，大小的形状
    out,_ = vae(z)
    a = out.view(-1, 1, 28, 28)
   # 生成一组样本图象
    save_image(a, os.path.join(sample_dir, '3-50测试.png'.format(epoch+1)))

In [61]:
out.mean

tensor([[-0.5004, -0.5003, -0.4999,  ..., -0.5004, -0.5004, -0.4999],
        [-0.5004, -0.5001, -0.4997,  ..., -0.5001, -0.5001, -0.4999],
        [-0.5000, -0.4999, -0.4998,  ..., -0.4999, -0.4999, -0.5000],
        ...,
        [-0.5004, -0.5003, -0.5000,  ..., -0.5005, -0.5005, -0.5000],
        [-0.4994, -0.4994, -0.4992,  ..., -0.4996, -0.4996, -0.4995],
        [-0.5000, -0.4999, -0.4999,  ..., -0.5001, -0.5001, -0.4999]],
       device='cuda:0')