In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
    

In [2]:

# Generator
class Generator(nn.Module):
    def __init__(self, z_dim, c_dim, output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(z_dim + c_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z, c):
        x = torch.cat([z, c], dim=1)  # Concatenate noise and latent code
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.tanh(self.fc4(x))
        return x
    

In [3]:

# Discriminator
class Discriminator(nn.Module):
    def __init__(self, input_dim, c_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)
        self.fc5 = nn.Linear(256, c_dim)  # For predicting latent code
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        validity = self.sigmoid(self.fc4(x))  # Real vs Fake
        c_pred = self.fc5(x)  # Predicted latent code
        return validity, c_pred
    

In [4]:

# Auxiliary Classifier
class AuxiliaryClassifier(nn.Module):
    def __init__(self, input_dim, c_dim):
        super(AuxiliaryClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, c_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        c_pred = self.fc4(x)  # Predicted latent code
        return c_pred
    

In [5]:

# Loss Functions
def generator_loss(D, G, z, c, lambda_):
    # Generator tries to fool the discriminator
    fake_images = G(z, c)
    validity, _ = D(fake_images)
    
    # InfoGAN regularization term: maximizing mutual information
    c_pred = Q(fake_images)
    info_loss = torch.mean((c_pred - c) ** 2)
    
    # GAN loss (adversarial loss)
    g_loss = -torch.mean(torch.log(validity))
    
    # Combine GAN loss and info loss
    total_loss = g_loss + lambda_ * info_loss
    return total_loss

def discriminator_loss(D, G, real_images, fake_images, c, Q):
    # Discriminator tries to classify real vs fake and predict latent code
    validity_real, c_pred_real = D(real_images)
    validity_fake, c_pred_fake = D(fake_images)
    
    # Discriminator loss (real vs fake)
    d_loss_real = -torch.mean(torch.log(validity_real))
    d_loss_fake = -torch.mean(torch.log(1 - validity_fake))
    
    # Latent code prediction loss
    info_loss = torch.mean((c_pred_fake - c) ** 2)
    
    # Total loss
    d_loss = d_loss_real + d_loss_fake + info_loss
    return d_loss
    

In [None]:

# Hyperparameters
z_dim = 100  # Random noise dimension
c_dim = 10   # Latent code dimension (for MNIST, it can be 10)
lr = 0.0002  # Learning rate
batch_size = 64
lambda_ = 0.1  # Regularization factor for InfoGAN

# DataLoader for MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Models and optimizers
G = Generator(z_dim, c_dim, 28*28)
D = Discriminator(28*28, c_dim)
Q = AuxiliaryClassifier(28*28, c_dim)

optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_Q = optim.Adam(Q.parameters(), lr=lr, betas=(0.5, 0.999))

# Training loop
for epoch in range(50):  # Number of epochs
    for real_images, _ in train_loader:
        real_images = real_images.view(-1, 28*28)  # Flatten MNIST images
        
        # Create random latent codes and noise
        z = torch.randn(batch_size, z_dim)
        c = torch.randint(0, c_dim, (batch_size,)).long()  # Random latent code
        c_one_hot = torch.zeros(batch_size, c_dim).scatter_(1, c.unsqueeze(1), 1)

        # Train discriminator
        fake_images = G(z, c_one_hot)
        d_loss = discriminator_loss(D, G, real_images, fake_images, c_one_hot, Q)

        optimizer_D.zero_grad()
        optimizer_Q.zero_grad()
        d_loss.backward()
        optimizer_D.step()
        optimizer_Q.step()

        # Train generator
        g_loss = generator_loss(D, G, z, c_one_hot, lambda_)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1}/50], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")
    

Epoch [1/50], D Loss: 0.5554754137992859, G Loss: 4.990442276000977
Epoch [2/50], D Loss: 0.2226797342300415, G Loss: 4.185049057006836
Epoch [3/50], D Loss: 0.4260225296020508, G Loss: 2.8924405574798584
Epoch [4/50], D Loss: 0.18990181386470795, G Loss: 3.4604358673095703
Epoch [5/50], D Loss: 0.3381187617778778, G Loss: 3.182933807373047
Epoch [6/50], D Loss: 0.788844108581543, G Loss: 2.6332483291625977
Epoch [7/50], D Loss: 0.17768675088882446, G Loss: 3.5525999069213867
Epoch [8/50], D Loss: 0.5918861627578735, G Loss: 2.6699342727661133
Epoch [9/50], D Loss: 0.8713591694831848, G Loss: 3.095942974090576
Epoch [10/50], D Loss: 0.8181343674659729, G Loss: 2.05318546295166
Epoch [11/50], D Loss: 0.8504294753074646, G Loss: 1.7079942226409912
Epoch [12/50], D Loss: 0.5771058201789856, G Loss: 1.9396847486495972
Epoch [13/50], D Loss: 0.7209910154342651, G Loss: 2.588313341140747
Epoch [14/50], D Loss: 0.5898507237434387, G Loss: 2.022254228591919
Epoch [15/50], D Loss: 0.75183248519

In [None]:

# Function to display generated images
def show_images(images, ncols=8):
    images = images.view(-1, 1, 28, 28)  # Reshape to image format
    grid = torchvision.utils.make_grid(images, nrow=ncols)
    plt.imshow(grid.permute(1, 2, 0).detach().numpy())
    plt.show()

# Show some generated images after training
z = torch.randn(64, z_dim)
c = torch.randint(0, c_dim, (64,)).long()  # Random latent codes
c_one_hot = torch.zeros(64, c_dim).scatter_(1, c.unsqueeze(1), 1)
generated_images = G(z, c_one_hot)
show_images(generated_images)
    