<a href="https://colab.research.google.com/github/Muhammad-Ikhwan-Fathulloh/Generative-Adversarial-Network-GAN/blob/main/avb_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from torchvision.datasets import MNIST
from torchvision import transforms

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
mnist = MNIST(root='./data', train=True, transform=transform, download=True)

# Normalize the data to [0, 1] range
mnist.data = mnist.data.float() / 255.0

# Parameters
mb_size = 32
z_dim = 5
X_dim = mnist.data.size(1) * mnist.data.size(2)  # Flattened image dimensions
h_dim = 128
lr = 1e-3

# Create noise dimension
eps_dim = 10  # Dimension of the noise vector

# Encoder: q(z|x,eps)
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim + eps_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)
)

# Decoder: p(x|z)
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid()
)

# Discriminator: T(X, z)
T = torch.nn.Sequential(
    torch.nn.Linear(X_dim + z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1)
)

def reset_grad():
    Q.zero_grad()
    P.zero_grad()
    T.zero_grad()


def sample_X(size):
    indices = np.random.randint(0, len(mnist), size)
    X = mnist.data[indices].view(size, -1).float()
    return Variable(X)


# Optimizers
Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
T_solver = optim.Adam(T.parameters(), lr=lr)

# Initialize counter
cnt = 0
"""1000000"""

# Your training loop goes here
for it in range(100000):
    X = sample_X(mb_size)
    eps = Variable(torch.randn(mb_size, eps_dim))
    z = Variable(torch.randn(mb_size, z_dim))

    # Optimize VAE
    z_sample = Q(torch.cat([X, eps], 1))
    X_sample = P(z_sample)
    T_sample = T(torch.cat([X, z_sample], 1))

    disc = torch.mean(-T_sample)
    loglike = -nn.BCELoss()(X_sample, X) # Use BCELoss for binary cross-entropy

    elbo = -(disc + loglike)

    elbo.backward()
    Q_solver.step()
    P_solver.step()
    reset_grad()

    # Discriminator T(X, z)
    z_sample = Q(torch.cat([X, eps], 1))
    T_q = nn.Sigmoid()(T(torch.cat([X, z_sample], 1)))
    T_prior = nn.Sigmoid()(T(torch.cat([X, z], 1)))

    T_loss = -torch.mean(torch.log(T_q) + torch.log(1. - T_prior))

    T_loss.backward()
    T_solver.step()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; ELBO: {:.4}; T_loss: {:.4}'
              .format(it, -elbo.item(), -T_loss.item()))

        samples = P(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)

Iter-0; ELBO: -0.6802; T_loss: -1.388
Iter-1000; ELBO: 1.359; T_loss: -4.232
Iter-2000; ELBO: 0.5327; T_loss: -2.327
Iter-3000; ELBO: -0.2026; T_loss: -1.408
Iter-4000; ELBO: -0.3201; T_loss: -1.356
Iter-5000; ELBO: -0.6626; T_loss: -1.146
Iter-6000; ELBO: -0.2099; T_loss: -1.627
Iter-7000; ELBO: -1.105; T_loss: -0.8719
Iter-8000; ELBO: -0.02747; T_loss: -1.683
Iter-9000; ELBO: -0.3182; T_loss: -1.393
Iter-10000; ELBO: -0.4565; T_loss: -1.299
Iter-11000; ELBO: -11.99; T_loss: -0.001405
Iter-12000; ELBO: -0.1754; T_loss: -1.362
Iter-13000; ELBO: -0.3587; T_loss: -1.299
Iter-14000; ELBO: -0.2247; T_loss: -1.367
Iter-15000; ELBO: -0.7013; T_loss: -1.154
Iter-16000; ELBO: -0.3266; T_loss: -1.399
Iter-17000; ELBO: -0.3998; T_loss: -1.488
Iter-18000; ELBO: -0.5714; T_loss: -1.162
Iter-19000; ELBO: -0.4501; T_loss: -1.206
Iter-20000; ELBO: -0.6838; T_loss: -1.192
Iter-21000; ELBO: -0.5736; T_loss: -1.262
Iter-22000; ELBO: -0.8184; T_loss: -1.164
Iter-23000; ELBO: -1.064; T_loss: -0.8597
Iter-