<a href="https://colab.research.google.com/github/Muhammad-Ikhwan-Fathulloh/Generative-Adversarial-Network-GAN/blob/main/aae_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision import transforms

# Per-item preprocessing pipeline handed to the dataset: convert to a tensor,
# then normalize with mean 0.5 and std 0.5.
# NOTE(review): sample_X indexes mnist.data directly, so batches drawn that
# way presumably never pass through this transform -- confirm if intended.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Open the MNIST training split stored under ./data, downloading if needed.
mnist = MNIST(root='./data', train=True, transform=transform, download=True)

# Swap the stored image tensor for a float copy divided by 255, so that
# pixel values land in the [0, 1] range.
mnist.data = mnist.data.float().div(255.0)

# Hyperparameters
mb_size = 32   # minibatch size
z_dim = 5      # width of the latent code
h_dim = 128    # width of every hidden layer
lr = 1e-3      # learning rate passed to the Adam optimizers
# Length of one flattened image: product of dims 1 and 2 of mnist.data.
X_dim = mnist.data.shape[1] * mnist.data.shape[2]

def _two_layer_net(in_dim, out_dim, squash=False):
    """Build Linear(in_dim, h_dim) -> ReLU -> Linear(h_dim, out_dim).

    When ``squash`` is true a trailing Sigmoid is appended. Layers are
    created in forward order and keep positional indices 0..3.
    """
    layers = [nn.Linear(in_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, out_dim)]
    if squash:
        layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


# Encoder: flattened image (X_dim) -> latent code (z_dim), no output activation.
Q = _two_layer_net(X_dim, z_dim)

# Decoder: latent code (z_dim) -> flattened image (X_dim), Sigmoid output.
P = _two_layer_net(z_dim, X_dim, squash=True)

# Discriminator: latent code (z_dim) -> single Sigmoid score.
D = _two_layer_net(z_dim, 1, squash=True)

def reset_grad():
    """Clear the gradients of the encoder Q, decoder P and discriminator D."""
    for net in (Q, P, D):
        net.zero_grad()

def sample_X(size):
    """Draw a random minibatch of flattened images from ``mnist.data``.

    Indices are drawn uniformly, with replacement, using
    ``np.random.randint`` over ``[0, len(mnist))``.

    Args:
        size: Number of images to draw (0 yields an empty batch).

    Returns:
        A float tensor with ``size`` rows, each row being one image from
        ``mnist.data`` flattened into a vector.
    """
    indices = np.random.randint(0, len(mnist), size)
    # flatten(start_dim=1) keeps the batch axis and collapses the rest; unlike
    # view(size, -1) it also works when size == 0, where -1 cannot be inferred.
    # The deprecated Variable wrapper is dropped: tensors carry autograd
    # support themselves, so wrapping was a no-op.
    return mnist.data[indices].flatten(start_dim=1).float()

# One Adam optimizer per network (encoder, decoder, discriminator), all
# sharing the learning rate `lr`.
Q_solver, P_solver, D_solver = (
    optim.Adam(net.parameters(), lr=lr) for net in (Q, P, D)
)

"""1000000"""
# Binary cross-entropy between reconstructions and inputs; built once and
# reused instead of being re-instantiated on every iteration.
recon_criterion = nn.BCELoss()

# Running index used to name the saved sample grids (out/000.png, ...).
# It has to exist before the loop: the snapshot branch reads and increments
# it, and without this initialisation the first snapshot raises NameError.
cnt = 0

for it in range(10000):
    X = sample_X(mb_size)

    # ---- Reconstruction phase: update decoder P and encoder Q ----
    z_sample = Q(X)
    X_sample = P(z_sample)

    # P already ends in a Sigmoid; the clamp is kept as a guard so BCELoss
    # always receives values within [0, 1].
    X_sample = X_sample.clamp(0, 1)

    recon_loss = recon_criterion(X_sample, X)

    recon_loss.backward()
    P_solver.step()
    Q_solver.step()
    reset_grad()

    # ---- Regularization phase ----
    # Discriminator step: latent codes drawn from a standard normal are the
    # "real" samples, encoder outputs are the "fake" ones.
    z_real = torch.randn(mb_size, z_dim)
    # Only D is stepped here, so the encoder output is detached; gradients
    # flowing into Q would be discarded by reset_grad() anyway.
    z_fake = Q(X).detach()

    D_real = D(z_real)
    D_fake = D(z_fake)

    D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Generator step: update the encoder Q so that D scores its codes as real.
    # Gradients that reach D here are cleared by reset_grad() without a step.
    z_fake = Q(X)
    D_fake = D(z_fake)

    G_loss = -torch.mean(torch.log(D_fake))

    G_loss.backward()
    Q_solver.step()
    reset_grad()

    # Every 1000 iterations: report the losses and save a grid of samples.
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
              .format(it, D_loss.item(), G_loss.item(), recon_loss.item()))

        # Decode the prior samples from this iteration; no gradients are
        # needed for plotting. At most 16 images fit the 4x4 grid.
        with torch.no_grad():
            samples = P(z_real).numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok avoids the separate existence check before creating out/.
        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)

Iter-0; D_loss: 1.417; G_loss: 0.6542; recon_loss: 0.6952
Iter-1000; D_loss: 1.423; G_loss: 0.7065; recon_loss: 0.2693
Iter-2000; D_loss: 1.5; G_loss: 0.5522; recon_loss: 0.2584
Iter-3000; D_loss: 1.388; G_loss: 0.7931; recon_loss: 0.2247
Iter-4000; D_loss: 1.439; G_loss: 0.6323; recon_loss: 0.1967
Iter-5000; D_loss: 1.403; G_loss: 0.6942; recon_loss: 0.1886
Iter-6000; D_loss: 1.379; G_loss: 0.7051; recon_loss: 0.188
Iter-7000; D_loss: 1.384; G_loss: 0.6985; recon_loss: 0.1765
Iter-8000; D_loss: 1.396; G_loss: 0.696; recon_loss: 0.1799
Iter-9000; D_loss: 1.394; G_loss: 0.6977; recon_loss: 0.1732
