In [12]:
import os
import sys
cwd = os.getcwd()
#add CIFAR10 data in the environment
sys.path.append(cwd + '/../cifar10')

#Numpy is linear algebra lbrary
import numpy as np
# Matplotlib is a visualizations library 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.features =16
        
        
        # encoder
        self.encoder = nn.Sequential(nn.Conv2d(3, 6, 5), 
                                     nn.ReLU(),
                                     nn.MaxPool2d(2, 2), 
                                     nn.Conv2d(6, 16, 5),
                                     nn.ReLU(),
                                     nn.MaxPool2d(2, 2),
                                     nn.Flatten(),
                                     nn.Linear(16 * 5 * 5, 120), 
                                     nn.ReLU(),
                                     nn.Linear(120, self.features * 2),
                                     nn.ReLU())

        # decoder
        self.decoder_linear = nn.Sequential(nn.Linear(in_features=self.features, out_features=120),
                                            nn.ReLU(),
                                            nn.Linear(in_features=120, out_features=16 * 5 * 5),
                                            nn.ReLU())
        self.decoder_conv = nn.Sequential(nn.ConvTranspose2d(16, 10, 8, padding=1),
                                          nn.ReLU(),
                                          nn.ConvTranspose2d(10, 6, 21, padding=1),
                                          nn.ReLU(),
                                          nn.ConvTranspose2d(6, 3, 5),
                                          nn.Flatten(),
                                         nn.Sigmoid())
                                          
                                     
    def forward(self, x):
        # encoding
        x = self.encoder(x).view(-1, 2, self.features)
        # get `mu` and `log_var`
        mu = x[:, 0, :]  # the first feature values as mean
        log_var = x[:, 1, :]  # the other feature values as variance
        # get the latent vector through reparameterization
        z = self.reparameterize(mu, log_var)

        # decoding
        x = self.decoder_linear(z)
        x = x.view(-1, 16, 5, 5)
        reconstruction = self.decoder_conv(x)
        return reconstruction, mu, log_var

    def reparameterize(self, mu, log_var):
            """
            :param mu: mean from the encoder's latent space
            :param log_var: log variance from the encoder's latent space
            """
            std = torch.exp(0.5 * log_var)  # standard deviation
            eps = torch.randn_like(std)  # generate sample of the same size
            sample = mu + (eps * std)  # sampling as if coming from the input space
            return sample

def final_loss(bce_loss, mu, logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the
    KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: recontruction loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    """
    BCE = bce_loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train(model,training_data):

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss(reduction='sum')

    running_loss = 0.0

    for epoch in range(1):  # loop over the dataset multiple times

        for i, data in enumerate(training_data, 0):
            inputs, _ = data
            inputs = inputs.view(-1, 3, 32, 32)

            optimizer.zero_grad()
            reconstruction, mu, logvar = model(inputs)
            inputs = inputs.view(inputs.size(0), -1)
            bce_loss = criterion(reconstruction, inputs)
            loss = final_loss(bce_loss, mu, logvar)
            running_loss += loss.item()
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    PATH = './cifar_net.pth'
    torch.save(model.state_dict(), PATH)

    print('Finished Training')

In [6]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [13]:
model = VAE()
train(model, trainloader)

  warn(f"Failed to load image Python extension: {e}")
  warn(f"Failed to load image Python extension: {e}")


[1,  2000] loss: -307884.120
[1,  4000] loss: -364185.945
[1,  6000] loss: -391470.040
[1,  8000] loss: -404930.131
[1, 10000] loss: -408856.913
[1, 12000] loss: -411827.272
Finished Training


In [None]:
x = torch.rand(16*5*5)
x = x.view(-1, 16, 5, 5)
print(x.size())

torch.Size([1, 16, 5, 5])
