Skip to content

Commit

Permalink
fix bug in vae
Browse files Browse the repository at this point in the history
  • Loading branch information
L1aoXingyu committed Sep 8, 2017
1 parent 765f849 commit c1f9a37
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions 08-AutoEncoder/Variational_autoencoder.py
Expand Up @@ -12,7 +12,6 @@
from torchvision.datasets import MNIST
import os


if not os.path.exists('./vae_img'):
os.mkdir('./vae_img')

Expand All @@ -32,7 +31,7 @@ def to_img(x):
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = MNIST('./data', transform=img_transform)
dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


Expand Down Expand Up @@ -73,7 +72,7 @@ def forward(self, x):
if torch.cuda.is_available():
model.cuda()

reconstruction_function = nn.BCELoss(size_average=False)
reconstruction_function = nn.MSELoss(size_average=False)


def loss_function(recon_x, x, mu, logvar):
Expand All @@ -93,7 +92,6 @@ def loss_function(recon_x, x, mu, logvar):

optimizer = optim.Adam(model.parameters(), lr=1e-3)


for epoch in range(num_epochs):
model.train()
train_loss = 0
Expand All @@ -111,12 +109,13 @@ def loss_function(recon_x, x, mu, logvar):
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(img), len(dataloader.dataset),
100. * batch_idx / len(dataloader),
epoch,
batch_idx * len(img),
len(dataloader.dataset), 100. * batch_idx / len(dataloader),
loss.data[0] / len(img)))

print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(dataloader.dataset)))
epoch, train_loss / len(dataloader.dataset)))
if epoch % 10 == 0:
save = to_img(recon_batch.cpu().data)
save_image(save, './vae_img/image_{}.png'.format(epoch))
Expand Down

0 comments on commit c1f9a37

Please sign in to comment.