# Variational Autoencoders Examples

## General Imports

In [3]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [11]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f673b270890>

## MNIST

### Load Dataset

In [4]:
# convert data to torch.FloatTensor
transform = transforms.ToTensor()

# transform=transforms.Compose([
#                             transforms.ToTensor(),
#                             transforms.Normalize((0.1307,), (0.3081,))
#                             ])

# choose the training and test datasets
train_data = datasets.MNIST(root='../data', train=True,
                               download=False, transform=transform)
test_data = datasets.MNIST(root='../data', train=False,
                              download=False, transform=transform)

Prepare data loaders

In [5]:
# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 128
# percentage of training set to use as validation
valid_size = 0.2
# epochs
epochs = 10

In [9]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, 
                            num_workers=num_workers)

test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, 
                            num_workers=num_workers)

### Build Model

In [7]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc2_mu = nn.Linear(400, 20)
        self.fc2_logvar = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc2_mu(h1), self.fc2_logvar(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + (std * eps)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        # flatten
        x = x.view(-1, 784)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

Loss Function and Optimizer

In [8]:
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

### VAE Loss

In [10]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

### Train Function

In [18]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 1 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

### Test Function

In [144]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

### Training

In [23]:
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')





====> Epoch: 1 Average loss: 116.8271
====> Test set loss: 114.1250






====> Epoch: 2 Average loss: 113.0601
====> Test set loss: 111.2357






====> Epoch: 3 Average loss: 110.9477
====> Test set loss: 109.8243






====> Epoch: 4 Average loss: 109.6111
====> Test set loss: 109.0256




====> Epoch: 5 Average loss: 108.6077
====> Test set loss: 108.2740






====> Epoch: 6 Average loss: 107.9328
====> Test set loss: 107.1464






====> Epoch: 7 Average loss: 107.3067
====> Test set loss: 106.8490






====> Epoch: 8 Average loss: 106.7937
====> Test set loss: 106.2790






====> Epoch: 9 Average loss: 106.4710
====> Test set loss: 106.6400




====> Epoch: 10 Average loss: 106.0661
====> Test set loss: 106.0083


## FRUITS

In [2]:
import torch
import os
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [6]:
# number of subprocesses to use for data loading
num_workers = 2
# how many samples per batch to load
batch_size = 128
# percentage of training set to use as validation
# valid_size = 0.2
# epochs
epochs = 30

log_interval = 50

### Data Loaders

Image size is  100 x 100

In [7]:
root_train = "../data/fruits-360/Training"
root_test = "../data/fruits-360/Test"

In [8]:
train_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
    root_train, transform=transforms.ToTensor()),
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
    root_test, transform=transforms.ToTensor()),
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=True)

In [9]:
train_loader.batch_size

128

In [10]:
len(train_loader)

529

In [11]:
len(train_loader.sampler)

67692

In [12]:
dataiter = iter(train_loader)
batch_data = dataiter.next()

In [13]:
len(batch_data)

2

In [14]:
batch_data[0].shape

torch.Size([128, 3, 100, 100])

In [15]:
batch_data[1].shape

torch.Size([128])

In [16]:
batch_data[1]


tensor([101, 108,  84,  97,  73,  85, 109, 129,  67,  92,   8, 123, 123,  73,
         63,  58,  97,  28,  80,  29,  44, 104,  28,  98,  91, 119,  27,  17,
         94,  76,  46, 104, 107,  55,  23,   7, 112,  21, 126,  57,  73,  68,
         78,  59,  72,  92,  19,  77,  58, 128, 104,  20, 122,  24, 118,  36,
         51,  35, 104, 117,  79, 107, 112,  45,  34,  31,  34, 120,  25,  16,
         57,   6,  85,  41,  38, 125,  61, 122,  70,  16,  21,  59,  32,  76,
        105, 128,   2,  28,  52,  29,  69,   8,  89,  67, 107,  99,  67, 111,
         96, 110,  78,  92,  27,  69,  28,  42,  52,  50, 110,  67,  76,   6,
         66,  60,  76, 129,  29,  17,  59,  22, 123,  73,  79,  30,  41,  79,
         68,  68])

### Build Model

<img src="img/VAE_Image_1.png" width="400" height="200">

In [43]:
class VAE_CNN_Fruits(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.conv1 = nn.Conv2d(3, 16, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.conv3 = nn.Conv2d(32, 16, kernel_size=5, stride=2)

        # Latent vectors mu and sigma
        self.fc1 = nn.Linear(10 * 10 * 16, 1024)
        self.fc21 = nn.Linear(1024, 1024)
        self.fc22 = nn.Linear(1024, 1024)

        # Sampling vector
        self.fc3 = nn.Linear(1024, 1024)
        self.fc4 = nn.Linear(1024, 10 * 10 * 16)

        # Decoder
        self.convT1 = nn.ConvTranspose2d(16, 32, kernel_size=5, stride=2)
        self.convT2 = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2)
        self.convT3 = nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2) 
        
    def encoder(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        
        x = self.conv2(x)
        x = F.relu(x)

        x = self.conv3(x)
        x = F.relu(x)
        
        x = x.view(-1, 10 * 10 * 16)
        
        x = self.fc1(x)
        x = F.relu(x)
        
        mu = self.fc21(x)
        logvar = self.fc22(x)
        
        return mu, logvar

    
    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = std.data.new(std.size()).normal_()
        return eps.mul(std).add_(mu)


    def decoder(self, z):
        x = self.fc3(z)
        x = F.relu(x)
        
        x = self.fc4(x)
        x = F.relu(x)
        
        x = x.view(-1, 16, 10, 10)
        
        x = self.convT1(x)
        x = F.relu(x)
        
        x = self.convT2(x)
        x = F.relu(x)
        
        x = self.convT3(x)
        x = F.relu(x)

        return x

    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

In [51]:
class VAE_Loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="sum")

    def forward(self, x_recon, x, mu, logvar):
        loss_MSE = self.mse_loss(x_recon, x)
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return loss_MSE + loss_KLD
    
# def VAE_Loss(recon_x, x, mu, logvar):
#     BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")
#     # BCE = F.mse_loss(recon_x, x, size_average=False)

#     # see Appendix B from VAE paper:
#     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
#     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
#     KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

#     return BCE + KLD, BCE, KLD

In [52]:
model = VAE_CNN_Fruits().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [53]:
print(model)

VAE_CNN_Fruits(
  (conv1): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2))
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
  (conv3): Conv2d(32, 16, kernel_size=(5, 5), stride=(2, 2))
  (fc1): Linear(in_features=1600, out_features=1024, bias=True)
  (fc21): Linear(in_features=1024, out_features=1024, bias=True)
  (fc22): Linear(in_features=1024, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=1600, bias=True)
  (convT1): ConvTranspose2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
  (convT2): ConvTranspose2d(32, 16, kernel_size=(5, 5), stride=(2, 2))
  (convT3): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2))
)


In [54]:
VAE_Loss = VAE_Loss()

###  Train

In [55]:
val_losses = []
train_losses = []

In [56]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = VAE_Loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                       loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    train_losses.append(train_loss / len(train_loader.dataset))

In [57]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += VAE_Loss(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.view(batch_size, 3, 100, 100)[:n]])
                save_image(comparison.cpu(),
                           'results/Fruit_reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    val_losses.append(test_loss)

In [58]:
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 1024).to(device)
        sample = model.decoder(sample).cpu()
        save_image(sample.view(64, 3, 100, 100),
                   'results/Fruit_sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 1584.7651
====> Test set loss: 754.5203
====> Epoch: 2 Average loss: 638.1795
====> Test set loss: 566.4806
====> Epoch: 3 Average loss: 511.5057
====> Test set loss: 487.4608
====> Epoch: 4 Average loss: 436.6867
====> Test set loss: 417.9083
====> Epoch: 5 Average loss: 378.9626
====> Test set loss: 363.3756
====> Epoch: 6 Average loss: 340.3447
====> Test set loss: 337.6168
====> Epoch: 7 Average loss: 315.9844
====> Test set loss: 320.3840
====> Epoch: 8 Average loss: 298.9626
====> Test set loss: 303.9112
====> Epoch: 9 Average loss: 286.9023
====> Test set loss: 297.0488
====> Epoch: 10 Average loss: 276.3297
====> Test set loss: 290.0110
====> Epoch: 11 Average loss: 266.9069
====> Test set loss: 278.1565
====> Epoch: 12 Average loss: 262.2719
====> Test set loss: 275.7483


====> Epoch: 13 Average loss: 255.9971
====> Test set loss: 268.4610
====> Epoch: 14 Average loss: 249.7846
====> Test set loss: 269.5839
====> Epoch: 15 Average loss: 245.8191
====> Test set loss: 263.1162
====> Epoch: 16 Average loss: 241.4299
====> Test set loss: 264.4838
====> Epoch: 17 Average loss: 238.1628
====> Test set loss: 260.9398
====> Epoch: 18 Average loss: 235.1436
====> Test set loss: 254.8337
====> Epoch: 19 Average loss: 232.1132
====> Test set loss: 252.9134
====> Epoch: 20 Average loss: 229.1236
====> Test set loss: 250.6922
====> Epoch: 21 Average loss: 227.5314
====> Test set loss: 248.5615
====> Epoch: 22 Average loss: 224.3366
====> Test set loss: 250.7365
====> Epoch: 23 Average loss: 222.4654
====> Test set loss: 246.9123
====> Epoch: 24 Average loss: 221.1848
====> Test set loss: 249.5986
====> Epoch: 25 Average loss: 218.7931
====> Test set loss: 249.6578


====> Epoch: 26 Average loss: 216.6648
====> Test set loss: 243.3757
====> Epoch: 27 Average loss: 215.2522
====> Test set loss: 242.1285
====> Epoch: 28 Average loss: 214.0626
====> Test set loss: 241.5115
====> Epoch: 29 Average loss: 212.4532
====> Test set loss: 242.9403
====> Epoch: 30 Average loss: 210.8139
====> Test set loss: 239.5776


## Helper Functions

In [1]:
def conv2d_dimensions(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1):

    width_out = ((width_in - kernel_size + (2 * padding)) / stride) + 1
    height_out = ((height_in - kernel_size + (2 * padding)) / stride) + 1
    output_channels = kernel_filters
    
    return width_out, height_out, output_channels

In [2]:
def convTranspose2d_dim(width_in, height_in, input_channels, kernel_size, kernel_filters, padding=0, stride=1):
    output_channels = kernel_filters
    width_out = stride * (width_in - 1) + kernel_size - (2 * padding)
    height_out = stride * (height_in - 1) + kernel_size - (2 * padding)
    
    return width_out, height_out, output_channels

## Helper Calculations

### conv2d

In [10]:
conv2d_dimensions(100, 100, 3, 3, 16)

(98.0, 98.0, 16)

In [22]:
conv2d_dimensions(98, 98, 16, 3, 32)

(96.0, 96.0, 32)

In [23]:
conv2d_dimensions(96/2, 96/2, 32, 3, 16, padding=0, stride=1)

(46.0, 46.0, 16)

Conv2D with stride=2 for downsampling

In [34]:
conv2d_dimensions(100, 100, 3, 4, 16, padding=0, stride=2)

(49.0, 49.0, 16)

In [31]:
conv2d_dimensions(49, 49, 16, 5, 32, padding=0, stride=2)

(23.0, 23.0, 32)

In [32]:
conv2d_dimensions(23, 23, 32, 5, 16, padding=0, stride=2)

(10.0, 10.0, 16)

### ConvTranspose2D

ConvTranspose2D Dimensions calculation

In [72]:
convTranspose2d_dim(23, 23, 16, 16, 32)

(38, 38, 32)

In [73]:
convTranspose2d_dim(38, 38, 32, 30, 16)

(67, 67, 16)

In [74]:
convTranspose2d_dim(67, 67, 16, 34, 3)

(100, 100, 3)

Conv2DTranspose with stride=2

In [40]:
convTranspose2d_dim(10, 10, 16, 5, 32, padding=0, stride=2)

(23, 23, 32)

### Flatten

In [21]:
t = torch.tensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                       [7, 8]]])
t

tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

In [22]:
torch.flatten(t)

tensor([1, 2, 3, 4, 5, 6, 7, 8])

In [23]:
torch.flatten(t, 1)

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

In [38]:
convTranspose2d_dim(23, 23, 32, 5, 16, padding=0, stride=2)

(49, 49, 16)

In [39]:
convTranspose2d_dim(49, 49, 16, 4, 3, padding=0, stride=2)

(100, 100, 3)

### Test BCE Loss values

INPUTS

In [22]:
input = torch.randn((2, 1), requires_grad=True)
input

tensor([[-0.9193],
        [ 0.9176]], requires_grad=True)

TARGETS

In [35]:
target = torch.rand((2, 1), requires_grad=False)
target

tensor([[0.2970],
        [0.4710]])

SIGMOIDS TO INPUTS

In [39]:
input_sig = torch.sigmoid(input)
input_sig

tensor([[0.2851],
        [0.7145]], grad_fn=<SigmoidBackward>)

In [49]:
loss = F.binary_cross_entropy(torch.sigmoid(input), target, reduction='mean')
loss

tensor(0.7151, grad_fn=<BinaryCrossEntropyBackward>)

In [46]:
-( (0.2970 * np.log(0.2851)) + (1 - 0.2970)*(np.log(1 - 0.2851)))

0.6086455012350224

In [47]:
-( (0.4710 * np.log(0.7145)) + (1 - 0.4710)*(np.log(1 - 0.7145)))

0.8214456538290854

In [50]:
(0.6086455012350224+0.8214456538290854)/2

0.7150455775320539