<a href="https://colab.research.google.com/github/Muhammad-Ikhwan-Fathulloh/Generative-Adversarial-Network-GAN/blob/main/dvae_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision import transforms

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
mnist = MNIST(root='./data', train=True, transform=transform, download=True)

# Normalize the data to [0, 1] range
mnist.data = mnist.data.float() / 255.0

# Parameters
mb_size = 32
z_dim = 5
X_dim = mnist.data.size(1) * mnist.data.size(2)  # Flattened image dimensions
h_dim = 128
lr = 1e-3
noise_factor = .25

# Define Q, P, sample_z, and other functions here

# Define Q(z|X)
Whz_mu = xavier_init(size=[h_dim, z_dim])
bhz_mu = Variable(torch.zeros(z_dim), requires_grad=True)

Whz_var = xavier_init(size=[h_dim, z_dim])
bhz_var = Variable(torch.zeros(z_dim), requires_grad=True)

def sample_z(mu, log_var):
    eps = Variable(torch.randn(mu.size()))
    return mu + torch.exp(log_var / 2) * eps

def Q(X):
    h = nn.ReLU()(X @ Wxh + bxh.repeat(X.size(0), 1))
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_var = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_var

# Define P(X|z)
Wzh = xavier_init(size=[z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

def P(z):
    h = nn.ReLU()(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = nn.Sigmoid()(h @ Whx + bhx.repeat(h.size(0), 1))
    return X


# Training
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)
c = 0  # Counter for saving images

"""1000000"""
for it in range(100000):
    X = X.view(-1, X_dim)  # Flatten the input image
    X = Variable(X)

    # Add noise
    X_noise = X + noise_factor * Variable(torch.randn(X.size()))
    X_noise.data.clamp_(0., 1.)

    # Forward
    z_mu, z_var = Q(X_noise)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss calculation
    recon_loss = F.binary_cross_entropy(X_sample, X, size_average=False) / mb_size
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    # Backward
    solver.zero_grad()
    loss.backward()
    solver.step()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

        z = Variable(torch.randn(mb_size, z_dim))
        samples = P(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)

Iter-0; Loss: 1.369e+03
Iter-1000; Loss: -3.083e+04
Iter-2000; Loss: -3.05e+04
Iter-3000; Loss: -3.097e+04
Iter-4000; Loss: -3.153e+04
Iter-5000; Loss: -3.128e+04
Iter-6000; Loss: -3.225e+04
Iter-7000; Loss: -3.319e+04
Iter-8000; Loss: -3.502e+04
Iter-9000; Loss: -3.57e+04
Iter-10000; Loss: -3.594e+04
Iter-11000; Loss: -3.73e+04
Iter-12000; Loss: -3.909e+04
Iter-13000; Loss: -4.072e+04
Iter-14000; Loss: -4.138e+04
Iter-15000; Loss: -4.232e+04
Iter-16000; Loss: -4.208e+04
Iter-17000; Loss: -4.258e+04
Iter-18000; Loss: -4.382e+04
Iter-19000; Loss: -4.479e+04
Iter-20000; Loss: -4.604e+04
Iter-21000; Loss: -4.495e+04
Iter-22000; Loss: -4.759e+04
Iter-23000; Loss: -4.562e+04
Iter-24000; Loss: -4.848e+04
Iter-25000; Loss: -4.705e+04
Iter-26000; Loss: -4.975e+04
Iter-27000; Loss: -5.227e+04
Iter-28000; Loss: -5.208e+04
Iter-29000; Loss: -5.412e+04
Iter-30000; Loss: -5.601e+04
Iter-31000; Loss: -5.554e+04
Iter-32000; Loss: -5.419e+04
Iter-33000; Loss: -5.666e+04
Iter-34000; Loss: -5.642e+04
It