In [1]:
#https://github.com/atinghosh/VAE-pytorch/blob/master/VAE_CNN_BCEloss.py
import os
import sys
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
from VAE_mnist_v1 import VAE_mnist
from VAE_CIFAR_v1 import VAE_CIFAR_v1
from VAE_CIFAR_v2 import VAE_CIFAR_v2

In [2]:
# Expose only GPU 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Use the GPU only when one is actually available, so the notebook also runs on
# CPU-only machines instead of crashing on the first .cuda() call.
CUDA = torch.cuda.is_available()
SEED = 1
BATCH_SIZE = 128
LOG_INTERVAL = 100   # print a training progress line every LOG_INTERVAL batches
EPOCHS = 25
no_of_sample = 10    # forwarded to the VAE_* constructors; NOTE(review): meaning is defined there
# Dataset selector: 0 = MNIST, 1 = CIFAR10, 2 = pre-processed CIFAR10 image folders.
DATASET = 1


ZDIMS = 20  # size of the latent vector fed to model.decode

# Seed the CPU RNG (and the CUDA RNG when a GPU is used) for reproducible runs.
torch.manual_seed(SEED)
if CUDA:
    torch.cuda.manual_seed(SEED)

In [3]:
# Extra loader worker and pinned host memory are only requested when batches go to the GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if CUDA else {}
to_tensor = transforms.ToTensor()

if DATASET == 0:
    print("MNIST")
    train_dataset = datasets.MNIST('./mnist', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.MNIST('./mnist', train=False, transform=to_tensor)
elif DATASET == 1:
    print("CIFAR")
    train_dataset = datasets.CIFAR10('./CIFAR10', train=True, download=True, transform=to_tensor)
    test_dataset = datasets.CIFAR10('./CIFAR10', train=False, transform=to_tensor)
elif DATASET == 2:
    print("CIFAR10_PROCESSED")
    # Pre-split train/test image folders on disk.
    processed_CIFAR10_data_train = datasets.ImageFolder(root='cifar10_processed_train/', transform=to_tensor)
    processed_CIFAR10_data_test = datasets.ImageFolder(root='cifar10_processed_test/', transform=to_tensor)
    train_dataset = processed_CIFAR10_data_train
    # Evaluate on the held-out folder (this previously reused the training folder,
    # so the reported "test" loss was really a training loss).
    test_dataset = processed_CIFAR10_data_test
else:
    raise ValueError(
        "DATASET must be 0 (MNIST), 1 (CIFAR10) or 2 (CIFAR10_PROCESSED), got {!r}".format(DATASET))

# Both loaders shuffle; for the test loader this means the first batch that test()
# turns into a reconstruction grid changes from epoch to epoch.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, **kwargs)


CIFAR
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./CIFAR10\cifar-10-python.tar.gz


100.0%

In [5]:
# Map each DATASET id to the label printed here and the VAE class built for it.
MODEL_CHOICES = {
    0: ("MNIST", VAE_mnist),
    1: ("CIFAR10", VAE_CIFAR_v1),
    2: ("CIFAR10_PROCESSED", VAE_CIFAR_v2),
}
if DATASET not in MODEL_CHOICES:
    # Fail here with a clear message instead of a NameError on `model` further down.
    raise ValueError(
        "DATASET must be one of {}, got {!r}".format(sorted(MODEL_CHOICES), DATASET))

model_label, model_class = MODEL_CHOICES[DATASET]
print(model_label)
# NOTE(review): how BATCH_SIZE / no_of_sample are used is defined in the VAE_* modules.
model = model_class(ZDIMS, BATCH_SIZE, no_of_sample)
if CUDA:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

CIFAR10


In [6]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Reads the notebook-level ``model``, ``optimizer``, ``train_loader``, ``CUDA``
    and ``LOG_INTERVAL``. ``model(data)`` must return ``(recon, mu, logvar)`` and
    ``model.loss_function(recon, data, mu, logvar)`` must return a scalar tensor.

    Args:
        epoch: epoch number, used only in the printed log lines.
    """
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        # Tensors are fed directly; the deprecated autograd.Variable wrapper is not needed.
        if CUDA:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            # NOTE(review): dividing by len(data) here (and by the dataset size below)
            # presumes loss_function returns a per-batch *sum*; if it returns a mean,
            # these figures are scaled down by roughly the batch size — confirm in VAE_*.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.9f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [7]:
def test(epoch):
    """Evaluate the model on ``test_loader``, print the loss and save a comparison grid.

    For the first test batch, up to 8 inputs are saved in one row with their
    reconstructions in the row below, as ``reconstruction_<epoch>.png`` in the
    dataset's output directory.

    Reads the notebook-level ``model``, ``test_loader``, ``CUDA`` and ``DATASET``.

    Args:
        epoch: epoch number, used in the saved file name.
    """
    model.eval()
    test_loss = 0
    # Per-dataset image shape (C, H, W) and output directory for the comparison grid.
    image_specs = {
        0: ((1, 28, 28), './mnist/'),
        1: ((3, 32, 32), './CIFAR10/'),
        2: ((3, 100, 100), './cifar10_processed/'),
    }

    # Inference only: no_grad stops autograd from recording a graph for every batch.
    # (The old Variable(..., volatile=True) flag no longer has any effect.)
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            if CUDA:
                data = data.cuda()
            recon_batch, mu, logvar = model(data)
            test_loss += model.loss_function(recon_batch, data, mu, logvar).item()
            if i == 0 and DATASET in image_specs:
                n = min(data.size(0), 8)
                image_shape, out_dir = image_specs[DATASET]
                # Size the view from the actual batch so a short batch cannot break it.
                recon_images = recon_batch.view(data.size(0), *image_shape)
                comparison = torch.cat([data[:n], recon_images[:n]])
                os.makedirs(out_dir, exist_ok=True)
                save_image(comparison.cpu(),
                           out_dir + 'reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [8]:
# Per-dataset image shape (C, H, W) and output directory for the random-sample grid.
SAMPLE_IMAGE_SPECS = {
    0: ((1, 28, 28), './mnist/'),
    1: ((3, 32, 32), './CIFAR10/'),
    2: ((3, 100, 100), './cifar10_processed/'),
}

for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch)

    # Decode 64 latent vectors of ZDIMS standard-normal floats each, to get a visual
    # idea of what the latent space generates. Pure inference, so autograd is off.
    with torch.no_grad():
        sample = torch.randn(64, ZDIMS)
        if CUDA:
            sample = sample.cuda()
        sample = model.decode(sample).cpu()

    # Save the 64 decoded samples as an 8x8 grid (save_image lays out 8 per row by default).
    if DATASET in SAMPLE_IMAGE_SPECS:
        image_shape, out_dir = SAMPLE_IMAGE_SPECS[DATASET]
        os.makedirs(out_dir, exist_ok=True)
        save_image(sample.view(64, *image_shape), out_dir + 'reconstruction' + str(epoch) + '.png')



====> Epoch: 1 Average loss: 0.004944606


  if sys.path[0] == '':


====> Test set loss: 0.0049
====> Epoch: 2 Average loss: 0.004824667
====> Test set loss: 0.0048
====> Epoch: 3 Average loss: 0.004815027
====> Test set loss: 0.0048
====> Epoch: 4 Average loss: 0.004806830
====> Test set loss: 0.0048
====> Epoch: 5 Average loss: 0.004803484
====> Test set loss: 0.0048
====> Epoch: 6 Average loss: 0.004799489
====> Test set loss: 0.0048
====> Epoch: 7 Average loss: 0.004796285
====> Test set loss: 0.0048
====> Epoch: 8 Average loss: 0.004792518
====> Test set loss: 0.0048
====> Epoch: 9 Average loss: 0.004788054
====> Test set loss: 0.0048
====> Epoch: 10 Average loss: 0.004784148
====> Test set loss: 0.0048
====> Epoch: 11 Average loss: 0.004782665
====> Test set loss: 0.0048
====> Epoch: 12 Average loss: 0.004781697
====> Test set loss: 0.0048
====> Epoch: 13 Average loss: 0.004780883
====> Test set loss: 0.0048
====> Epoch: 14 Average loss: 0.004778994
====> Test set loss: 0.0048
====> Epoch: 15 Average loss: 0.004778194
====> Test set loss: 0.0048


In [8]:
# Serialize the whole model object (not just its state_dict) to vae.pt.
# NOTE(review): unpickling it later needs the VAE_* class definitions importable;
# saving model.state_dict() would be the more portable choice.
torch.save(obj=model, f="vae.pt")

In [14]:
# Linearly interpolate between two random latent vectors and decode every point,
# giving a strip of images that morphs from one sample into the other.
num_interpolation_points = 16
sample = torch.randn(2, ZDIMS)
first_point = sample[0]
last_point = sample[1]

# Column of mixing weights 0 ... 1; broadcasting against the (ZDIMS,) endpoints
# yields one interpolated latent vector per row: shape (num_interpolation_points, ZDIMS).
mix = torch.linspace(0, 1, num_interpolation_points).unsqueeze(1)
interpolation_sample = (1 - mix) * first_point + mix * last_point
print(interpolation_sample.shape)
if CUDA:
    interpolation_sample = interpolation_sample.cuda()
with torch.no_grad():  # inference only
    interpolation_sample = model.decode(interpolation_sample).cpu()
print(interpolation_sample.shape)

# Per-dataset image shape (C, H, W) and output directory; DATASET 2 works at
# 100x100 and was previously saved with the 32x32 CIFAR10 shape.
interpolation_image_specs = {
    0: ((1, 28, 28), './mnist/'),
    1: ((3, 32, 32), './CIFAR10/'),
    2: ((3, 100, 100), './cifar10_processed/'),
}
image_shape, out_dir = interpolation_image_specs[DATASET]
os.makedirs(out_dir, exist_ok=True)
save_image(interpolation_sample.view(num_interpolation_points, *image_shape), out_dir + 'interpolation.png')
    

torch.Size([16, 20])
torch.Size([16, 3, 1024])


