In [12]:
#https://github.com/atinghosh/VAE-pytorch/blob/master/VAE_CNN_BCEloss.py
import os
import sys
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
from VAE_mnist_v1 import VAE_mnist
from VAE_CIFAR_v1 import VAE_CIFAR

In [13]:
# --- Run configuration ------------------------------------------------------
# Expose only the first GPU to this process (set before any CUDA call is made).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Use the GPU only when one is actually usable, so the notebook also runs on
# CPU-only machines instead of failing at the first `.cuda()` call.
CUDA = torch.cuda.is_available()
SEED = 1
BATCH_SIZE = 128
LOG_INTERVAL = 10    # training progress is printed every LOG_INTERVAL batches
EPOCHS = 25
no_of_sample = 10    # forwarded to the VAE constructor
MNIST = False        # True -> train on MNIST, False -> train on CIFAR-10


ZDIMS = 20           # size of the latent vectors fed to the decoder

# Seed the CPU generator (and the CUDA generator when in use) for repeatable runs.
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed(SEED)

In [14]:
# Extra DataLoader options, only applied when batches will be moved to the GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}

# Choose the torchvision dataset class and its on-disk root from the MNIST flag.
if MNIST:
    print("MNIST")
    dataset_cls, data_root = datasets.MNIST, './mnist'
else:
    print("CIFAR")
    dataset_cls, data_root = datasets.CIFAR10, './CIFAR10'

to_tensor = transforms.ToTensor()

# Training split: downloaded on first use and reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset_cls(data_root, train=True, download=True, transform=to_tensor),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Test split: NOT shuffled, so the first batch (whose reconstructions are saved
# once per epoch by test()) contains the same images every time and the saved
# grids are comparable across epochs.
test_loader = torch.utils.data.DataLoader(
    dataset_cls(data_root, train=False, transform=to_tensor),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)
    
# elif DATASET == 2:
#     print("CIFAR10_PROCESSED")
#     processed_CIFAR10_data = datasets.ImageFolder(root='cifar10_processed/', transform=transforms.ToTensor())
#     print(processed_CIFAR10_data.shape)
#     train_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./CIFAR10', train=True, download=True,transform=transforms.ToTensor()),
#         batch_size=BATCH_SIZE, shuffle=True, **kwargs)

#     test_loader = torch.utils.data.DataLoader(datasets.CIFAR10('./CIFAR10', train=False, transform=transforms.ToTensor()),
#         batch_size=BATCH_SIZE, shuffle=True, **kwargs)

CIFAR10_PROCESSED


AttributeError: 'ImageFolder' object has no attribute 'shape'

In [4]:
# Select the VAE variant matching the dataset flag, announce it, then build it.
if MNIST:
    dataset_label, vae_cls = "MNIST", VAE_mnist
else:
    dataset_label, vae_cls = "CIFAR10", VAE_CIFAR
print(dataset_label)
model = vae_cls(ZDIMS, BATCH_SIZE, no_of_sample)

# Move the model to the GPU before the optimizer is handed its parameters.
if CUDA:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-3)

CIFAR10


In [5]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``, ``CUDA``
    and ``LOG_INTERVAL``.

    Args:
        epoch: Epoch number; only used in the printed log lines.

    Returns:
        None. Prints a progress line every ``LOG_INTERVAL`` batches and the
        summed loss divided by the number of training examples at the end.
    """
    model.train()
    train_loss = 0
    n_batches = len(train_loader)
    n_examples = len(train_loader.dataset)

    for batch_idx, (data, _) in enumerate(train_loader):
        # Plain tensors take part in autograd directly (PyTorch >= 0.4), so the
        # deprecated Variable wrapper is not needed.
        if CUDA:
            data = data.cuda()

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss

        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), n_examples,
                100. * batch_idx / n_batches,
                batch_loss / len(data)))

    print('====> Epoch: {} Average loss: {:.9f}'.format(epoch, train_loss / n_examples))

In [8]:
def test(epoch):
    """Evaluate ``model`` on ``test_loader`` and save a reconstruction grid.

    Uses the module-level ``model``, ``test_loader``, ``CUDA`` and ``MNIST``.

    Args:
        epoch: Epoch number; used in the name of the saved image.

    Returns:
        None. Prints the summed test loss divided by the number of test
        examples, and writes ``reconstruction_<epoch>.png`` (up to 8 inputs
        from the first batch on the top row, their reconstructions below)
        into ``./mnist`` or ``./CIFAR10``.
    """
    model.eval()
    test_loss = 0

    # Per-dataset image shape (channels, height, width) and output directory.
    if MNIST:
        image_shape, out_dir = (1, 28, 28), './mnist'
    else:
        image_shape, out_dir = (3, 32, 32), './CIFAR10'

    # `Variable(..., volatile=True)` no longer switches autograd off; no_grad()
    # is the supported way to run inference without building a graph.
    with torch.no_grad():
        # each batch holds up to BATCH_SIZE samples
        for i, (data, _) in enumerate(test_loader):
            if CUDA:
                # make sure this lives on the GPU
                data = data.cuda()

            recon_batch, mu, logvar = model(data)
            test_loss += model.loss_function(recon_batch, data, mu, logvar).item()

            if i == 0:
                n = min(data.size(0), 8)
                # -1 lets the batch dimension follow the actual batch, so a
                # first batch shorter than BATCH_SIZE does not break the view.
                recon_images = recon_batch.view(-1, *image_shape)
                comparison = torch.cat([data[:n], recon_images[:n]])
                save_image(comparison.cpu(),
                           out_dir + '/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [11]:
for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch)

    # Decode 64 random ZDIMS-dimensional latent vectors. No gradients are
    # needed for sampling, so autograd is switched off for this step.
    with torch.no_grad():
        sample = torch.randn(64, ZDIMS)
        if CUDA:
            sample = sample.cuda()
        sample = model.decode(sample).cpu()

    # Save the 64 decoded images as a grid; this gives a visual idea of how
    # well points drawn from the latent space decode into plausible images.
    if MNIST:
        save_image(sample.view(64, 1, 28, 28), './mnist/reconstruction' + str(epoch) + '.png')
    else:
        save_image(sample.view(64, 3, 32, 32), './CIFAR10/reconstruction' + str(epoch) + '.png')



====> Epoch: 1 Average loss: 0.004839524


  if sys.path[0] == '':


====> Test set loss: 0.0049
====> Epoch: 2 Average loss: 0.004832656
====> Test set loss: 0.0048
====> Epoch: 3 Average loss: 0.004828928
====> Test set loss: 0.0048
====> Epoch: 4 Average loss: 0.004826477
====> Test set loss: 0.0048
====> Epoch: 5 Average loss: 0.004824585


====> Test set loss: 0.0048
====> Epoch: 6 Average loss: 0.004823194
====> Test set loss: 0.0048
====> Epoch: 7 Average loss: 0.004821385
====> Test set loss: 0.0048
====> Epoch: 8 Average loss: 0.004819659
====> Test set loss: 0.0048
====> Epoch: 9 Average loss: 0.004818147


====> Test set loss: 0.0048
====> Epoch: 10 Average loss: 0.004818260
====> Test set loss: 0.0048
====> Epoch: 11 Average loss: 0.004816743
====> Test set loss: 0.0048
====> Epoch: 12 Average loss: 0.004816029
====> Test set loss: 0.0048


====> Epoch: 13 Average loss: 0.004814775
====> Test set loss: 0.0048
====> Epoch: 14 Average loss: 0.004814038
====> Test set loss: 0.0048
====> Epoch: 15 Average loss: 0.004813535
====> Test set loss: 0.0048
====> Epoch: 16 Average loss: 0.004812820
====> Test set loss: 0.0048


====> Epoch: 17 Average loss: 0.004812304
====> Test set loss: 0.0048
====> Epoch: 18 Average loss: 0.004811654
====> Test set loss: 0.0048
====> Epoch: 19 Average loss: 0.004811681
====> Test set loss: 0.0048
====> Epoch: 20 Average loss: 0.004811005
====> Test set loss: 0.0048


====> Epoch: 21 Average loss: 0.004810778
====> Test set loss: 0.0048
====> Epoch: 22 Average loss: 0.004810240
====> Test set loss: 0.0048
====> Epoch: 23 Average loss: 0.004810057
====> Test set loss: 0.0048
====> Epoch: 24 Average loss: 0.004809578
====> Test set loss: 0.0048


====> Epoch: 25 Average loss: 0.004808904
====> Test set loss: 0.0048


In [None]:
# Write the trained model object to "vae.pt" in the current working directory.
# NOTE(review): this serialises the whole module object rather than its
# state_dict, so loading presumably needs the VAE class importable — confirm
# before relying on this file outside this project.
torch.save(model, "vae.pt")

In [None]:
# Decode evenly spaced points on the straight line between two random latent
# vectors and save the resulting images as a single grid.
num_interpolation_points = 10

# Image shape and output directory must follow the dataset the model was
# trained on; hard-coding the MNIST values breaks the reshape for CIFAR-10.
if MNIST:
    image_shape, out_dir = (1, 28, 28), './mnist'
else:
    image_shape, out_dir = (3, 32, 32), './CIFAR10'

with torch.no_grad():
    # Two random endpoints in latent space.
    sample = torch.randn(2, ZDIMS)
    first_point, last_point = sample[0], sample[1]

    # Column of weights in [0, 1]; broadcasting against the endpoints yields a
    # (num_interpolation_points, ZDIMS) tensor of interpolated latents.
    weights = torch.linspace(0, 1, num_interpolation_points).unsqueeze(1)
    interpolation_sample = (1 - weights) * first_point + weights * last_point

    if CUDA:
        interpolation_sample = interpolation_sample.cuda()
    interpolation_sample = model.decode(interpolation_sample).cpu()

save_image(interpolation_sample.view(num_interpolation_points, *image_shape),
           out_dir + '/interpolation.png')