In [3]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm import tqdm
import os.path

In [4]:
features = 16  # default size of the latent vector


class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    Keyword Args:
        input_shape (int): number of values in one flattened input sample.
        mid_dim (int): width of the hidden layer of both encoder and decoder.
        latent_dim (int, optional): size of the latent vector. Defaults to the
            module-level ``features`` constant, so existing callers that do not
            pass it get the same network as before.

    Raises:
        KeyError: if ``input_shape`` or ``mid_dim`` is missing.
    """

    def __init__(self, **kwargs):
        super().__init__()

        input_shape = kwargs["input_shape"]
        mid_dim = kwargs["mid_dim"]
        self.latent_dim = kwargs.get("latent_dim", features)

        # Encoder layers. The last layer emits 2 * latent_dim values: one half
        # is read as the mean, the other half as the log-variance.
        self.encoder1 = nn.Linear(in_features=input_shape, out_features=mid_dim)
        self.encoder2 = nn.Linear(in_features=mid_dim, out_features=self.latent_dim * 2)

        # Decoder layers.
        self.decoder1 = nn.Linear(in_features=self.latent_dim, out_features=mid_dim)
        self.decoder2 = nn.Linear(in_features=mid_dim, out_features=input_shape)

    def reparametrize(self, mu, log_var):
        """Draw a latent sample as ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        Args:
            mu: mean of the encoder's latent distribution.
            log_var: log-variance of the encoder's latent distribution.

        Returns:
            A tensor with the same shape as ``mu``.
        """
        # exp(0.5 * log_var) == sqrt(var): the 0.5 turns log-variance into log-std.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # noise with the same shape as std
        return mu + (eps * std)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: input whose last dimension has ``input_shape`` values.

        Returns:
            Tuple ``(reconstruction, mu, log_var, z)``; the reconstruction has
            passed through a sigmoid, so its values lie in [0, 1].
        """
        # Encode.
        hidden = F.relu(self.encoder1(x))
        stats = self.encoder2(hidden).view(-1, 2, self.latent_dim)

        # Split the encoder output into mean and log-variance.
        mu = stats[:, 0, :]       # first latent_dim values: mean
        log_var = stats[:, 1, :]  # remaining latent_dim values: log-variance

        z = self.reparametrize(mu, log_var)  # sample from the latent distribution

        # Decode.
        hidden = F.relu(self.decoder1(z))
        reconstruction = torch.sigmoid(self.decoder2(hidden))
        return reconstruction, mu, log_var, z



In [5]:
# Network: 784 flattened input values, one 512-unit hidden layer on each side.
model = VAE(mid_dim=512, input_shape=784)

# Reconstruction term: binary cross-entropy, summed (not averaged) over the batch.
# MSE or cross-entropy? The choice follows the distribution assumed for the
# reconstructed values:
#   - normally distributed values    -> MSE
#   - multinomial / Bernoulli values -> cross-entropy
# Note that the cross-entropy error is NOT symmetrical (author's note: it is
# biased towards 0.5).
criterion = nn.BCELoss(reduction="sum")

optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [6]:
# Preprocessing: the only step is conversion to tensors.
transform = T.Compose([T.ToTensor()])

# MNIST train / test splits, downloaded on first use (train split is built first).
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root='../input/data',
        train=is_train,
        download=True,
        transform=transform,
    )
    for is_train in (True, False)
)

# Training and validation data loaders: only the training data is shuffled.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=do_shuffle)
    for dataset, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [7]:
# Fetch a single batch to check the dimensions of the data.
data, target = next(iter(train_loader))
# Expected: batch size as configured on the loader (32), 1 channel (grey scale), 28x28 image.
print("Batch, channels, height, width: ", data.shape)

print("How many examples (training): ", len(train_dataset))
print("How many examples (test): ", len(test_dataset))

Batch, channels, height, width:  torch.Size([32, 1, 28, 28])
How many examples (training):  60000
How many examples (test):  10000


In [8]:
def final_loss(mu, logvar, reconstruction_loss):
    """Return the VAE objective: KL divergence plus reconstruction loss.

    Args:
        mu: mean of the latent distribution.
        logvar: log-variance of the latent distribution.
        reconstruction_loss: already-computed reconstruction term.

    Returns:
        ``KL + reconstruction_loss``, where the KL term is summed over every
        element of ``mu`` / ``logvar`` (closed form from Appendix B of the VAE
        paper).
    """
    variance = logvar.exp()
    per_element = 1 + logvar - mu.pow(2) - variance
    kl_divergence = -0.5 * torch.sum(per_element)
    return kl_divergence + reconstruction_loss

In [11]:
def fit(model, dataloader):
    """Train ``model`` for one pass over ``dataloader``.

    Relies on the notebook-level ``optimizer``, ``criterion`` and
    ``final_loss``.

    Args:
        model: network returning ``(reconstruction, mu, logvar, z)``.
        dataloader: yields ``(inputs, labels)`` batches; labels are ignored.

    Returns:
        Accumulated loss divided by ``len(dataloader.dataset)``.
    """
    model.train()
    running_loss = 0.0
    # len(dataloader) is the exact number of batches of the loader being
    # iterated, including a partial final batch, so the progress bar is correct
    # for any loader rather than only for the global training set.
    for data, _ in tqdm(dataloader, total=len(dataloader)):
        data = data.view(data.size(0), -1)  # flatten each sample
        optimizer.zero_grad()  # reset the gradients back to zero
        reconstruction, mu, logvar, _ = model(data)  # compute reconstructions
        reconstruction_loss = criterion(reconstruction, data)
        # Real loss: reconstruction + KL divergence.
        loss = final_loss(mu, logvar, reconstruction_loss)
        running_loss += loss.item()
        loss.backward()  # compute gradients
        optimizer.step()  # update the weights
    train_loss = running_loss / len(dataloader.dataset)  # average loss per sample
    return train_loss

In [16]:
def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and save sample images.

    Relies on the notebook-level ``criterion`` and ``final_loss``, and reads the
    notebook-level ``epoch`` variable to name the saved image files.

    For the last batch, up to 8 inputs with their reconstructions and the
    matching latent codes are written to ``ImageOutputVAE/`` under the current
    working directory.

    Args:
        model: network returning ``(reconstruction, mu, logvar, z)``.
        dataloader: yields ``(inputs, labels)`` batches; labels are ignored.

    Returns:
        Accumulated loss divided by ``len(dataloader.dataset)``.
    """
    model.eval()
    running_loss = 0.0
    # Index of the real last batch. len(dataloader) already accounts for a
    # partial final batch, unlike floor(dataset_size / batch_size).
    num_batches = len(dataloader)
    last_batch_index = num_batches - 1
    with torch.no_grad():  # no gradients are needed during validation
        for i, (data, _) in tqdm(enumerate(dataloader), total=num_batches):
            data = data.view(data.size(0), -1)
            reconstruction, mu, logvar, coded = model(data)
            reconstruction_loss = criterion(reconstruction, data)
            loss = final_loss(mu, logvar, reconstruction_loss)
            running_loss += loss.item()
            # Same as training but without backpropagation.

            # Save the last batch input and output of every epoch.
            if i == last_batch_index:
                # Use the actual size of this batch: the last one may be
                # smaller than the loader's batch_size.
                n = data.size(0)
                num_rows = min(8, n)
                both = torch.cat((data.view(n, 1, 28, 28)[:num_rows],
                                  reconstruction.view(n, 1, 28, 28)[:num_rows]))
                # Each latent code is drawn as a (latent_dim / 2) x 2 image.
                code = coded.view(n, 1, -1, 2)[:num_rows]

                final_directory = os.path.join(os.getcwd(), 'ImageOutputVAE')
                os.makedirs(final_directory, exist_ok=True)
                torchvision.utils.save_image(
                    both, os.path.join(final_directory, f"Output{epoch}.png"), nrow=num_rows)
                torchvision.utils.save_image(
                    code, os.path.join(final_directory, f"Code{epoch}.png"), nrow=num_rows)

    val_loss = running_loss / len(dataloader.dataset)
    return val_loss

In [17]:
epochs = 20
batch_size = 32  # must match the batch size the DataLoaders were built with

# Per-epoch loss history for the training and test sets.
train_loss, test_loss = [], []

# `epoch` is also read as a global by validate() when it names the saved images.
for epoch in range(epochs):
    print(f"\n Epoch {epoch+1} of {epochs}")

    train_epoch_loss = fit(model, train_loader)
    test_epoch_loss = validate(model, test_loader)

    train_loss += [train_epoch_loss]
    test_loss += [test_epoch_loss]

    print(f"\nTrain Loss: {train_epoch_loss:.4f}")
    print(f"Test Loss: {test_epoch_loss:.4f}")

  0%|          | 5/1875 [00:00<00:43, 42.80it/s]
 Epoch 1 of 20
100%|██████████| 1875/1875 [00:46<00:00, 39.97it/s]
313it [00:01, 186.71it/s]                         
  0%|          | 6/1875 [00:00<00:35, 52.18it/s]
Train Loss: 124.9917
Test Loss: 120.8866

 Epoch 2 of 20
100%|██████████| 1875/1875 [00:34<00:00, 54.00it/s]
313it [00:01, 191.72it/s]                         
  0%|          | 5/1875 [00:00<00:37, 49.55it/s]
Train Loss: 120.2016
Test Loss: 117.5895

 Epoch 3 of 20
100%|██████████| 1875/1875 [00:34<00:00, 53.74it/s]
313it [00:01, 187.00it/s]                         
  0%|          | 6/1875 [00:00<00:36, 51.01it/s]
Train Loss: 117.3967
Test Loss: 115.1735

 Epoch 4 of 20
100%|██████████| 1875/1875 [00:35<00:00, 52.23it/s]
313it [00:01, 182.46it/s]                         
  0%|          | 5/1875 [00:00<00:39, 47.02it/s]
Train Loss: 115.4009
Test Loss: 113.4772

 Epoch 5 of 20
100%|██████████| 1875/1875 [00:39<00:00, 47.22it/s]
313it [00:01, 166.37it/s]                       

KeyboardInterrupt: 