In [3]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MNIST training split, stored under ./data (downloaded if not already there).
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# MNIST held-out split, same storage location and transform.
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Training batches are reshuffled; image batches come out as
# (batch_size, channel_size, height, width).
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

# Evaluation batches keep the dataset order.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=64,
    shuffle=False,
)



In [4]:
# Pull a single batch from the training loader to inspect its layout.
sample = iter(train_loader)
images, labels = next(sample)

# Show the image batch shape.
print(images.size())

torch.Size([64, 1, 28, 28])


### The encoder and decoder networks are much the same as those of a plain autoencoder. The key idea a variational autoencoder adds is the reparameterization trick. If we sample the latent vector (z) directly from a normal distribution with the computed mean and variance, we can't backpropagate through that step, because it is stochastic rather than deterministic. Instead, we use the shift-and-scale property of the normal distribution: sample random noise from the standard normal distribution, then scale it by the standard deviation and shift it by the mean.

In [19]:
# linear variational autoencoder

class Linearvariationalautoencoder(nn.Module):
  """Fully connected variational autoencoder for 1x28x28 images.

  The encoder maps a flattened 784-value input to a 28-dimensional feature
  vector. Two linear heads turn that vector into the mean and the
  log-variance of the latent distribution. A latent sample is drawn with the
  reparameterization trick and decoded back to 784 values squashed into
  (0, 1) by a sigmoid.

  Args:
    latent_dim: size of the latent vector z (default 2).
  """

  def __init__(self,latent_dim=2):
    super().__init__()

    # 784 -> 128 -> 64 -> 28 feature extractor.
    self.encoder = nn.Sequential(
        nn.Linear(28*28,128),
        nn.ReLU(),
        nn.Linear(128,64),
        nn.ReLU(),
        nn.Linear(64,28)
    )

    # latent_dim -> 64 -> 128 -> 784 reconstruction network.
    self.decoder = nn.Sequential(
        nn.Linear(latent_dim,64),
        nn.ReLU(),
        nn.Linear(64,128),
        nn.ReLU(),
        nn.Linear(128,28*28),
        nn.Sigmoid()
    )

    # Heads producing the latent mean and log-variance from the encoder features.
    self.mu = nn.Linear(28, latent_dim)
    self.logvar = nn.Linear(28, latent_dim)

  def forward(self,x):
    """Encode x, sample a latent vector, and decode it.

    Args:
      x: input batch; everything after the first dimension is flattened and
        must total 784 values per sample (e.g. shape (B, 1, 28, 28)).

    Returns:
      Tuple (x_hat, mu, logvar): x_hat has shape (B, 1, 28, 28); mu and
      logvar have shape (B, latent_dim).
    """
    x = x.flatten(1)
    x = self.encoder(x)

    mu = self.mu(x)
    logvar = self.logvar(x)
    # The head predicts log(sigma^2), so sigma = exp(0.5 * logvar).
    std = torch.exp(0.5 *logvar)

    # Reparameterization: z = mu + eps * sigma with eps ~ N(0, I), which keeps
    # the sampling step differentiable with respect to mu and logvar.
    # randn_like already matches std's dtype and device.
    noise = torch.randn_like(std)
    z = mu + noise * std

    x_hat = self.decoder(z)
    # Infer the batch dimension instead of hard-coding 64, so batches of any
    # size (such as a smaller final batch from a DataLoader) are handled.
    x_hat = x_hat.reshape(-1,1,28,28)
    return x_hat,mu,logvar

# Build the VAE with a 2-dimensional latent space and move it to the chosen device.
model = Linearvariationalautoencoder(latent_dim=2)
model = model.to(device)
# Bare expression so the notebook displays the module structure.
model

Linearvariationalautoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=28, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=784, bias=True)
    (5): Sigmoid()
  )
  (mu): Linear(in_features=28, out_features=2, bias=True)
  (logvar): Linear(in_features=28, out_features=2, bias=True)
)

In [21]:
# Shape smoke test: push one random batch through the model.
test_x = torch.rand(64, 1, 28, 28)
test_x = test_x.to(device)
x_hat, mu, logvar = model(test_x)

# One shape per line: reconstruction, latent mean, latent log-variance.
print(x_hat.shape, mu.shape, logvar.shape, sep='\n')

torch.Size([64, 1, 28, 28])
torch.Size([64, 2])
torch.Size([64, 2])
