In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

In [15]:
# Loading the dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5,))]
)
dataset = datasets.FashionMNIST('./', train=True, download=True, transform=transform)
data = DataLoader(dataset, batch_size=64, shuffle=False)

In [16]:
# generator network
class Generator(nn.Module):
    def __init__(self, noise_dim, img_dim):
        super().__init__()
        
        self.g_theta = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, img_dim),
            nn.Tanh()
        )
        
    def forward(self, z):
        return self.g_theta(z)
        

In [17]:
class Discriminator(nn.Module):
    def __init__(self, img_dim, noise_dim):
        super().__init__()
        
        self.D_w = nn.Sequential(
            nn.Linear(img_dim+noise_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x, z):
        i = torch.concat([x, z], dim=1)
        return self.D_w(i)

In [18]:
class Encoder(nn.Module):
    def __init__(self, img_dim, noise_dim):
        super().__init__()
        
        self.E_phi = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, noise_dim)
        )
    
    def forward(self, x):
        return self.E_phi(x)


In [19]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [20]:
random_noise_dim = 150
img_dim = 784

In [21]:
generator = Generator(noise_dim= random_noise_dim, img_dim= img_dim).to(device)
discriminator = Discriminator(img_dim, random_noise_dim).to(device)
encoder = Encoder(img_dim, random_noise_dim).to(device)
generator, discriminator, encoder

(Generator(
   (g_theta): Sequential(
     (0): Linear(in_features=150, out_features=256, bias=True)
     (1): LeakyReLU(negative_slope=0.01)
     (2): Linear(in_features=256, out_features=512, bias=True)
     (3): LeakyReLU(negative_slope=0.01)
     (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (5): Linear(in_features=512, out_features=784, bias=True)
     (6): Tanh()
   )
 ),
 Discriminator(
   (D_w): Sequential(
     (0): Linear(in_features=934, out_features=512, bias=True)
     (1): LeakyReLU(negative_slope=0.01)
     (2): Linear(in_features=512, out_features=256, bias=True)
     (3): LeakyReLU(negative_slope=0.01)
     (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (5): Linear(in_features=256, out_features=1, bias=True)
     (6): Sigmoid()
   )
 ),
 Encoder(
   (E_phi): Sequential(
     (0): Linear(in_features=784, out_features=512, bias=True)
     (1): LeakyReLU(negative_slope=0.01)
     (2):

In [22]:
optimizer_d = optim.Adam(discriminator.parameters(), lr=2e-4)
optimizer_ge = optim.Adam(list(generator.parameters()) + list(encoder.parameters()), lr=2e-4)

criterion = nn.BCELoss()

In [23]:

def show_generated_images(epoch, generator, fixed_noise):
    generator.eval()
    with torch.no_grad():
        fake_imgs = generator(fixed_noise).reshape(-1, 1, 28, 28)
        fake_imgs = fake_imgs * 0.5 + 0.5  # De-normalize

    grid = torchvision.utils.make_grid(fake_imgs, nrow=8)
    plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.title(f'Generated Images at Epoch {epoch}')
    plt.axis('off')
    plt.show()
    generator.train()

In [24]:
fixed_random_samples = torch.randn(64, random_noise_dim).to(device) # samples to check the output generated by generator across epochs.

In [None]:
epochs = 100
for epoch in range(epochs):
    for idx, (real_data, label) in enumerate(data):
        
        batch_size = real_data.shape[0]
        real_data = real_data.view(batch_size, -1).to(device)
        
        # ones = torch.ones(batch_size, 1).to(device)
        # zeros = torch.zeros(batch_size, 1).to(device)
        
        # Training the discriminator
        generator.eval()
        encoder.eval()
        discriminator.train()
        
        z_sampled = torch.randn(batch_size, random_noise_dim).to(device)
        z_hat = encoder(real_data)
        x_hat = generator(z_sampled)
        
        D_real = discriminator(real_data, z_hat)
        D_generated = discriminator(x_hat, z_sampled)
        
        ones = torch.ones_like(D_real)
        zeros = torch.zeros_like(D_generated)
        
        disc_loss = criterion(D_real, ones) + criterion(D_generated, zeros)
        
        optimizer_d.zero_grad()
        disc_loss.backward()
        optimizer_d.step()

        # Training the generator and encoder
        generator.train()
        encoder.train()
        discriminator.eval()
        
        for _ in range(5):
            z_sampled = torch.randn(batch_size, random_noise_dim).to(device)
            z_hat = encoder(real_data)
            x_hat = generator(z_sampled)
            
            D_real = discriminator(real_data, z_hat)
            D_generated = discriminator(x_hat, z_sampled)
            
            ge_loss = criterion(D_real, zeros) + criterion(D_generated, ones)
            
            optimizer_ge.zero_grad()
            ge_loss.backward()
            optimizer_ge.step()
            
    if (epoch+1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], D_loss: {disc_loss.item():.4f}, G_loss: {ge_loss.item():.4f}")
            show_generated_images(epoch+1, generator, fixed_random_samples)
    else:
        print(f"Epoch [{epoch+1}/{epochs}], D_loss: {disc_loss.item():.4f}, G_loss: {ge_loss.item():.4f}")
    

Epoch [1/100], D_loss: 1.3867, G_loss: 0.8368
Epoch [2/100], D_loss: 1.3866, G_loss: 0.7476
Epoch [3/100], D_loss: 1.3864, G_loss: 0.9393
Epoch [4/100], D_loss: 1.3865, G_loss: 1.1876
Epoch [5/100], D_loss: 1.3863, G_loss: 1.1943
Epoch [6/100], D_loss: 1.3866, G_loss: 1.0886
Epoch [7/100], D_loss: 1.3864, G_loss: 0.6906
Epoch [8/100], D_loss: 1.3863, G_loss: 1.3008
