In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

In [3]:
class VAE(nn.Module):
  """Convolutional variational autoencoder for single-channel 28x28 images.

  The encoder maps an image to the mean and log-variance of a diagonal
  Gaussian q(z|x) over a ``latent_dim``-dimensional latent space. The decoder
  maps a latent vector back to a 1x28x28 image with values in (0, 1).

  Args:
    latent_dim: Size of the latent vector ``z``.
  """

  def __init__(self, latent_dim):
    super().__init__()
    self.latent_dim = latent_dim

    # 1x28x28 -> 32x14x14 -> 64x7x7 -> flatten -> 128 features.
    self.encoder = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64 * 7 * 7, 128),
        nn.ReLU(),
    )

    # Two heads on the shared 128-dim encoding: mean and log-variance of q(z|x).
    self.z_mean = nn.Linear(128, latent_dim)
    self.z_log_var = nn.Linear(128, latent_dim)

    # Mirror of the encoder: latent -> 128 -> 64x7x7 -> 32x14x14 -> 1x28x28.
    self.decoder = nn.Sequential(
        nn.Linear(latent_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 64 * 7 * 7),
        nn.ReLU(),
        nn.Unflatten(1, (64, 7, 7)),
        nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.ReLU(),
        nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
        nn.Sigmoid(),  # squashes reconstructed pixels into (0, 1)
    )

  # Named ``encode`` (not ``encoder``): a method called ``encoder`` would shadow
  # the ``self.encoder`` submodule, because nn.Module keeps submodules in
  # ``_modules`` and normal attribute lookup finds the class method first.
  # Calling ``self.encoder(x)`` from inside it would then recurse forever.
  def encode(self, x):
    """Encode an image batch into the parameters of q(z|x).

    Args:
      x: Image batch; the 64*7*7 Linear layer implies shape (batch, 1, 28, 28).

    Returns:
      Tuple ``(mu, log_var)``, each of shape (batch, latent_dim).
    """
    h = self.encoder(x)
    return self.z_mean(h), self.z_log_var(h)

  def reparamet(self, mu, log_var):
    """Draw z ~ N(mu, exp(log_var)) with the reparameterization trick.

    Sampling noise separately keeps z differentiable w.r.t. mu and log_var.

    Args:
      mu: Latent means.
      log_var: Latent log-variances, same shape as ``mu``.

    Returns:
      A sample ``z`` with the same shape as ``mu``.
    """
    std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z):
    """Decode latent vectors of shape (batch, latent_dim) into (batch, 1, 28, 28) images."""
    return self.decoder(z)

  def forward(self, x):
    """Run encode -> sample -> decode.

    Returns:
      Tuple ``(x_hat, mu, log_var)``: the reconstruction plus the posterior
      parameters needed for the KL term of the VAE loss.
    """
    mu, log_var = self.encode(x)
    z = self.reparamet(mu, log_var)
    x_hat = self.decode(z)
    return x_hat, mu, log_var

In [5]:
# Load MNIST; ToTensor() converts each image to a float tensor with values in [0, 1].
BATCH_SIZE = 128

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training split. A fixed test order keeps evaluation metrics
# and reconstruction plots reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Display the training split's summary (number of datapoints, root, split, transform).
train_dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [7]:
# Display the test split's summary (number of datapoints, root, split, transform).
test_dataset

Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [8]:
# NOTE(review): a DataLoader's repr is only an object address; consider
# displaying len(test_loader) (number of batches) instead.
test_loader

<torch.utils.data.dataloader.DataLoader at 0x7b5ec887a650>