$$ \text{Cross-entropy loss} = -\sum_{i} P_i \log Q_i \quad \text{(}P\text{: target distribution, } Q\text{: predicted distribution)} $$
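$$ \text{KL divergence loss} = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\right) \quad \text{for a Gaussian posterior } \mathcal{N}(\mu, \sigma^{2}) \text{ against the prior } \mathcal{N}(0, I) $$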

In [None]:
# Download the preprocessed MNIST training set (a torch-serialized file read below with torch.load) into the working directory.
!wget https://github.com/MorvanZhou/PyTorch-Tutorial/raw/master/tutorial-contents-notebooks/mnist/processed/training.pt

--2021-10-25 13:25:33--  https://github.com/MorvanZhou/PyTorch-Tutorial/raw/master/tutorial-contents-notebooks/mnist/processed/training.pt
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/MorvanZhou/PyTorch-Tutorial/master/tutorial-contents-notebooks/mnist/processed/training.pt [following]
--2021-10-25 13:25:33--  https://raw.githubusercontent.com/MorvanZhou/PyTorch-Tutorial/master/tutorial-contents-notebooks/mnist/processed/training.pt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 47520431 (45M) [application/octet-stream]
Saving to: ‘training.pt’


2021-10-25 13:25:34 (164 MB/s) - ‘training.pt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU instead
# of failing later when tensors are moved to a missing CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the downloaded training file. It unpacks into two items; only the first
# (the images) is kept — the second is presumably the labels, which an
# autoencoder does not need.
# NOTE(review): torch.load unpickles the file; only load files from a source you
# trust. On newer PyTorch versions consider torch.load(..., weights_only=True).
X_train, _ = torch.load('training.pt')

# Reshape to (N, 1, 28, 28): one input channel of 28x28 pixels, as the
# Conv2d encoder (in_channels=1) expects.
X_train = X_train.reshape(-1, 1, 28, 28)

In [None]:
# Mini-batch size used by the training loader.
bs = 64

# The autoencoder reconstructs its own input, so one float tensor serves as both
# the input and the target of every (input, target) pair. Converting once avoids
# holding two identical float copies of the whole dataset in memory.
# NOTE(review): pixels are passed through unscaled. If the file stores 0-255
# integers, dividing by 255 is the usual choice — but it also rescales the MSE
# term relative to the KL term in the loss, so confirm before changing.
X_train_float = X_train.float()
train_data_loader_vae = DataLoader(
    TensorDataset(X_train_float, X_train_float), batch_size=bs, shuffle=True
)

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Encoder: four stride-2 convolutions shrink the spatial size
    28 -> 14 -> 7 -> 3 -> 1 while widening the channels 1 -> 32 -> 64 -> 128
    -> 256. The flattened 256-vector is projected to the mean and
    log-variance of a diagonal Gaussian over the latent code.

    Decoder: a linear layer maps the latent code back to 256 features. These
    are treated as a 256x1x1 map and upsampled 1 -> 3 -> 7 -> 14 -> 28 by
    four stride-2 transposed convolutions into a 1-channel image.

    Args:
        latent_dim: Size of the latent code ``z`` (default 20).
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder convolutions (channel counts written as multiples of 8).
        self.enc_cv1 = nn.Conv2d(in_channels=1, out_channels=4*8, kernel_size=4, padding=1, stride=2)
        self.enc_cv2 = nn.Conv2d(in_channels=4*8, out_channels=8*8, kernel_size=4, padding=1, stride=2)
        self.enc_cv3 = nn.Conv2d(in_channels=8*8, out_channels=16*8, kernel_size=3, padding=0, stride=2)
        self.enc_cv4 = nn.Conv2d(in_channels=16*8, out_channels=32*8, kernel_size=2, padding=0, stride=2)
        # Heads for the posterior mean and log-variance. `enc_var` keeps its
        # name (for state_dict compatibility) but produces log(sigma^2).
        self.enc_mean = nn.Linear(in_features=32*8, out_features=latent_dim)
        self.enc_var = nn.Linear(in_features=32*8, out_features=latent_dim)

        # Decoder: project z back to 256 features, then mirror the encoder
        # with transposed convolutions.
        self.dec_fc = nn.Linear(in_features=latent_dim, out_features=32*8)
        self.dec_cv1 = nn.ConvTranspose2d(in_channels=32*8, out_channels=16*8, stride=2, kernel_size=3, padding=0)
        self.dec_cv2 = nn.ConvTranspose2d(in_channels=16*8, out_channels=8*8, stride=2, kernel_size=3, padding=0)
        self.dec_cv3 = nn.ConvTranspose2d(in_channels=8*8, out_channels=4*8, stride=2, kernel_size=4, padding=1)
        self.dec_cv4 = nn.ConvTranspose2d(in_channels=4*8, out_channels=1, stride=2, kernel_size=4, padding=1)

    @staticmethod
    def _reparameterize(mu, log_var):
        """Return z = mu + sigma * eps with eps ~ N(0, I).

        Sampling through the external noise ``eps`` keeps ``z``
        differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoder(self, x):
        """Encode a batch of images into a sampled latent code.

        Args:
            x: Image batch; the layer sizes are chosen for (batch, 1, 28, 28)
                inputs, which the conv stack reduces to a 256x1x1 map.

        Returns:
            Tuple ``(z, mu, log_var)``, each of shape (batch, latent_dim):
            a sample from the posterior, its mean and its log-variance.
        """
        x = F.relu(self.enc_cv1(x))
        x = F.relu(self.enc_cv2(x))
        x = F.relu(self.enc_cv3(x))
        x = F.relu(self.enc_cv4(x))

        # Flatten the 256x1x1 feature map to (batch, 256) for the linear heads.
        x = x.view(x.size(0), -1)

        mu = self.enc_mean(x)
        log_var = self.enc_var(x)
        z = self._reparameterize(mu, log_var)
        return z, mu, log_var

    def decoder(self, x):
        """Decode latent codes into images.

        Args:
            x: Latent codes of shape (batch, latent_dim). A single unbatched
                vector of shape (latent_dim,) is also accepted and yields a
                batch of one.

        Returns:
            Tensor of shape (batch, 1, 28, 28). The final ReLU makes every
            output pixel non-negative.
        """
        x = F.relu(self.dec_fc(x))
        # Treat the 256 features as a 256-channel 1x1 feature map.
        x = x.reshape(-1, 32*8, 1, 1)

        x = F.relu(self.dec_cv1(x))
        x = F.relu(self.dec_cv2(x))
        x = F.relu(self.dec_cv3(x))
        x = F.relu(self.dec_cv4(x))
        return x

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        z, mu, log_var = self.encoder(x)
        return self.decoder(z), mu, log_var

In [None]:
model_vae = VAE()
# Train on the GPU when one is available, otherwise on the CPU. The device is
# chosen here rather than hard-coded so the model is never pinned to the CPU
# on a machine that has a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_vae = model_vae.to(device)
# Adam with its default hyperparameters over all VAE parameters.
opt = Adam(model_vae.parameters())

In [None]:
num_epochs = 200
for epoch in range(num_epochs):
    model_vae.train()
    training_loss_vae = 0.0
    num_batches = 0

    # Weight on the KL term for this epoch: 1, 1/2, 1/3, ...
    # NOTE(review): this weight shrinks over training, the opposite of the
    # usual KL warm-up. Late in training the KL term barely regularizes the
    # latent space, which can hurt samples decoded from N(0, I). Confirm this
    # schedule is intended.
    kl_weight = 1.0 / (epoch + 1)

    for data_vae, _ in train_data_loader_vae:
        data_vae = data_vae.to(device)

        opt.zero_grad()
        output, mu, log_var = model_vae(data_vae)

        # Reconstruction term: mean squared error over all pixels.
        recon_loss = F.mse_loss(output, data_vae, reduction='mean')
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), averaged over the batch
        # and latent dimensions.
        kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
        loss_vae = recon_loss + kl_weight * kl_loss

        loss_vae.backward()
        opt.step()
        training_loss_vae += loss_vae.item()
        num_batches += 1

    if (epoch + 1) % 10 == 0:
        # Report the mean per-batch loss so the number is comparable across
        # batch sizes; max() guards against an empty loader.
        avg_loss = training_loss_vae / max(num_batches, 1)
        print(f"Epoch {epoch + 1} Training loss: {avg_loss}")

NameError: ignored

In [None]:
import matplotlib.pyplot as plt

with torch.no_grad():
    # Sample one latent vector from the standard-normal prior. Its length is
    # read from the model's mean head so it always matches the latent size.
    latent_dim = model_vae.enc_mean.out_features
    noise = torch.randn(1, latent_dim, device=device)
    generated_image = model_vae.decoder(noise)

    # The decoder returns a (1, 1, 28, 28) batch; show it as a 28x28 image.
    plt.imshow(generated_image.cpu().numpy().reshape(28, 28), cmap="gray")