In [1]:
# import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
from torch.nn.modules.activation import LeakyReLU
class VAE(nn.Module):
  def __init__(self, in_channel, image_size, hidden_dim, z_dim):
    super().__init__()
    self.in_channel = in_channel
    self.image_size = image_size
    self.hidden_dim = hidden_dim
    self.z_dim = z_dim


    self.encoder = nn.Sequential(
        nn.Conv2d(in_channels=self.in_channel, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1, inplace=True),

        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1, inplace=True),

        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1, inplace=True),

        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.1, inplace=True),

        nn.Flatten(),
        nn.Linear(256, self.hidden_dim)
    )

    self.mu = nn.Linear(self.hidden_dim, self.z_dim)
    self.sigma = nn.Linear(self.hidden_dim, self.z_dim)


    self.decoder_first_layer = nn.Linear(self.z_dim, 256)

    self.decoder = nn.Sequential(
        nn.Unflatten(dim=1, unflattened_size=(64, 2, 2)),
        
        nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, bias=False, padding=1),
        nn.LeakyReLU(0.1, inplace=True),

        nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, bias=False, padding=1),
        nn.LeakyReLU(0.1, inplace=True),

        nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, bias=False, padding=1),
        nn.LeakyReLU(0.1, inplace=True),
        
        nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, bias=False, padding=1),
        nn.LeakyReLU(0.1, inplace=True),

        nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=2, bias=False, padding=1),
        
    ) 


  def encoder_forw(self, mzg):
    mzg = self.encoder(mzg)

    mu = self.mu(mzg)
    sigma = self.sigma(mzg)

    epsilon = torch.randn_like(sigma)
    encoded = mu + sigma * epsilon

    return encoded, mu, sigma



  def decoder_forw(self, mzg):
    mzg = self.decoder_first_layer(mzg)

    mzg = self.decoder(mzg)

    mzg = mzg[:, :, :self.image_size, :self.image_size]

    decoded = torch.sigmoid(mzg)

    return decoded


  def forward(self, mzg):
    encoded, mu, sigma = self.encoder_forw(mzg)

    # output = decoded
    output = self.decoder_forw(encoded) 

    return output, mu, sigma

In [4]:
# train function
def train(num_epochs, model, train_loader, optimizer, loss_fn):
    # training loop
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader))
        for i, (x, y) in loop:
            # Forward pass
            x = x.to(device)

            x_reconst, mu, sigma = model(x)

            # loss
            reconst_loss = loss_fn(x_reconst, x)

            kl_div = - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

            # optimizing and backpropagation
            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())

In [5]:
n_channel = 1
image_size = 28
z_dim = 20
h_dim = 200
num_epochs = 5
batch_size = 64
lr = 3e-4

In [6]:
model_vae = VAE(in_channel=n_channel, image_size=image_size, hidden_dim=h_dim, z_dim=z_dim)

model_vae.to(device)

VAE(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.1, inplace=True)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (3): LeakyReLU(negative_slope=0.1, inplace=True)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (5): LeakyReLU(negative_slope=0.1, inplace=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): LeakyReLU(negative_slope=0.1, inplace=True)
    (8): Flatten(start_dim=1, end_dim=-1)
    (9): Linear(in_features=256, out_features=200, bias=True)
  )
  (mu): Linear(in_features=200, out_features=20, bias=True)
  (sigma): Linear(in_features=200, out_features=20, bias=True)
  (decoder_first_layer): Linear(in_features=20, out_features=256, bias=True)
  (decoder): Sequential(
    (0): Unflatten(dim=1, unflattened_size=(64, 2, 2))
    (1): ConvTranspose2

In [7]:
# Dataset loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [8]:
# optimizer and loss functions
optimizer = torch.optim.Adam(model_vae.parameters(), lr=lr)
# loss_fn = nn.MSELoss()
loss_fn = nn.BCELoss(reduction="sum")

In [9]:
# train our model 
train(num_epochs, model_vae, train_loader, optimizer, loss_fn)

938it [00:24, 38.56it/s, loss=5e+3]
938it [00:14, 64.52it/s, loss=4.8e+3]
938it [00:14, 64.22it/s, loss=4.21e+3]
938it [00:15, 61.90it/s, loss=4.27e+3]
938it [00:14, 63.19it/s, loss=4.19e+3]


In [10]:
# Save model
FILE = "model_pytorch.pth"
torch.save(model_vae, FILE)