<a href="https://colab.research.google.com/github/HannahSu6/Condition-InfoGAN-Mutual-Information/blob/main/CGAN_with_Mutual_Information_Maximization_Qin.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Conditional Generative Adversarial Network (CGAN) with Mutual Information Maximization
This notebook demonstrates the implementation of a CGAN that generates images conditioned on specific classes and adds a mutual information maximization term intended to improve the quality and diversity of the generated images. The mutual information term encourages the generator to keep its latent codes recoverable from the generated image, so that the codes end up controlling meaningful features of the output.

This is a PyTorch demonstration of a conditional GAN that integrates mutual information, trained on the MNIST dataset.
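
As background, the mutual information objective in InfoGAN-style models is usually stated as a variational lower bound. The notation below is assumed for this sketch rather than taken from the notebook: $c$ is the continuous latent code, $z$ stands for all other generator inputs (noise and class label), $G$ is the generator, and $Q$ is an auxiliary distribution that predicts $c$ from a generated image $x$.

```latex
I\big(c;\, G(z, c)\big) \;\ge\; \mathbb{E}_{c \sim p(c),\; x = G(z, c)}\big[\log Q(c \mid x)\big] + H(c)
```

If $Q$ is taken to be a Gaussian with fixed variance centered on the predicted code $\hat{c}(x)$, maximizing the right-hand side amounts to minimizing $\lVert c - \hat{c}(x) \rVert^2$, which is consistent with the mean-squared-error information loss used later in this notebook.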

In [None]:
import argparse
import os
import numpy as np
import math
import itertools

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt



## Directory Setup for Output Images
Directories are set up to store images generated during the training process. This setup helps in visually inspecting the performance of the generator at various stages of training.


In [None]:
os.makedirs("images/static/", exist_ok=True)
os.makedirs("images/varying_c1/", exist_ok=True)
os.makedirs("images/varying_c2/", exist_ok=True)

## Simulation of Command Line Arguments
For the ease of running this notebook interactively, command line arguments are simulated within the notebook. These arguments help in configuring major parameters such as the number of training epochs, batch size, learning rate, etc., typically passed when executing a script from a command line.


In [None]:
import argparse
import sys

# Simulate command line arguments
sys.argv = ['ipykernel_launcher.py', '--n_epochs', '20', '--batch_size', '64']

# Create parser and add arguments
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=62, help="dimensionality of the latent space")
parser.add_argument("--code_dim", type=int, default=2, help="latent code")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")

# Parse arguments
opt = parser.parse_args()

# Print the parsed options
print(opt)


Namespace(n_epochs=20, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=8, latent_dim=62, code_dim=2, n_classes=10, img_size=32, channels=1, sample_interval=400)
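
An alternative to overwriting `sys.argv` is to hand the argument list to `parse_args` directly, which leaves the kernel's own arguments untouched. The sketch below is a minimal standalone illustration with only two of the options; the names `demo_parser`, `demo_args`, and `demo_opt` are made up for this example.

```python
import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
demo_parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")

# Hypothetical list standing in for the command line; nothing is read from sys.argv
demo_args = ["--n_epochs", "20"]
demo_opt = demo_parser.parse_args(demo_args)

# Options that are not listed fall back to their defaults
print(demo_opt.n_epochs, demo_opt.batch_size)
```

Passing an empty list (`parse_args([])`) would yield all defaults, which can be convenient when experimenting interactively.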


## Weight Initialization
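Weight initialization can significantly affect the training dynamics and stability of GANs. Here, convolution weights are drawn from a normal distribution with mean 0 and standard deviation 0.02, and batch-normalization weights from a normal distribution with mean 1 and standard deviation 0.02 with biases set to 0, which is a common practice for GANs. The cell also defines `to_categorical`, a helper that one-hot encodes class labels.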


In [None]:
cuda = torch.cuda.is_available()


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def to_categorical(y, num_columns):
    """Returns one-hot encoded Variable"""
    y_cat = np.zeros((y.shape[0], num_columns))
    y_cat[range(y.shape[0]), y] = 1.0

    return Variable(FloatTensor(y_cat))
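
Because `weights_init_normal` dispatches on the class name, it can help to see which layer types it actually touches. The sketch below mirrors only the name checks and returns a description instead of modifying weights; `init_rule` is a hypothetical helper written for this illustration.

```python
def init_rule(classname):
    """Mirror the class-name checks in weights_init_normal, returning a description."""
    if classname.find("Conv") != -1:
        return "weight ~ normal(mean 0.0, std 0.02)"
    elif classname.find("BatchNorm") != -1:
        return "weight ~ normal(mean 1.0, std 0.02), bias = 0"
    return "left at PyTorch default"

# Layer types that appear in the networks of this notebook
for name in ["Conv2d", "BatchNorm2d", "Linear", "Upsample", "LeakyReLU"]:
    print(f"{name:12s} -> {init_rule(name)}")
```

Under these checks, the `Linear` layers are not re-initialized and keep PyTorch's default initialization.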


## Network Architectures
### Generator Architecture
The generator's role is to synthesize plausible images from random noise conditioned on class labels and additional latent codes aimed at capturing more complex variations.

In the `Generator` class, the noise vector, one-hot class label, and latent code are concatenated and passed through a linear layer that produces `128 * init_size ** 2` features. These are reshaped into a four-dimensional tensor of shape `(batch size, 128, init_size, init_size)`, where `init_size` is a quarter of the target image size (`opt.img_size // 4`). The reshaped tensor then passes through convolutional blocks: two upsampling steps each double the spatial size, and 3x3 convolutions refine the features and reduce the channel count. The final output has shape `(batch size, opt.channels, opt.img_size, opt.img_size)`, where `opt.channels` is the number of image color channels, and a `Tanh` activation maps pixel values into [-1, 1].


In [None]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        input_dim = opt.latent_dim + opt.n_classes + opt.code_dim
        # noise vector + one-hot class label + continuous latent code (code_dim = 2 by default)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(input_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
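            # Note: the second positional argument of BatchNorm2d is eps, so 0.8 sets eps here
            # (not momentum); the same applies to the other BatchNorm2d(..., 0.8) calls.
            nn.BatchNorm2d(128, 0.8),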
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels, code):
        gen_input = torch.cat((noise, labels, code), -1)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
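
The shape flow through the generator can be traced with plain arithmetic. The sketch below assumes the default option values from the argument parser above and does not build any network; it only tracks `(channels, height, width)` per stage.

```python
# Default values from the argument parser in this notebook
latent_dim, n_classes, code_dim = 62, 10, 2
img_size, channels = 32, 1

input_dim = latent_dim + n_classes + code_dim   # length of the concatenated input vector
init_size = img_size // 4                       # spatial size before the two upsampling steps
linear_out = 128 * init_size ** 2               # features produced by the linear layer

shapes = [("linear output, reshaped", (128, init_size, init_size))]
size = init_size
for out_channels in (128, 64):
    size *= 2  # nn.Upsample(scale_factor=2); the 3x3 conv with padding 1 keeps the size
    shapes.append((f"upsample + conv to {out_channels} channels", (out_channels, size, size)))
shapes.append(("final conv + tanh", (channels, size, size)))

print("input_dim =", input_dim, "| linear_out =", linear_out)
for name, shape in shapes:
    print(f"{name:34s} {shape}")
```

With these defaults the input vector has 74 entries, the linear layer emits 8192 features, and the final stage yields a 1 x 32 x 32 image.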




### Discriminator Architecture
The discriminator's job is to distinguish between real and generated images, providing feedback to the generator. It also helps in inferring the latent codes from the images, which is crucial for mutual information maximization.

The `Discriminator` class scores whether an image is real or fake and also predicts the latent code that produced it. Four convolutional blocks with increasing filter counts (16, 32, 64, 128) each halve the spatial size with a stride-2 convolution, followed by LeakyReLU, dropout for regularization, and (except in the first block) batch normalization. The resulting feature maps are flattened, concatenated with the one-hot label, and passed through two linear layers. Two output heads follow: a single value indicating whether the image is real or fake, and `opt.code_dim` values that estimate the latent code. The second head is what the mutual information term is trained on.

In [None]:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        self.process = nn.Sequential(
            # Flattened conv features plus the one-hot label (opt.n_classes entries)
            nn.Linear(128 * ds_size ** 2 + opt.n_classes, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 48),
            nn.LeakyReLU(0.2, inplace=True),
            )


        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(48, 1))
        self.latent_layer = nn.Sequential(nn.Linear(48, opt.code_dim))

    def forward(self, img, label):
        out = self.conv_blocks(img)
        # Flatten the feature maps and prepend the one-hot label
        out = out.view(out.shape[0], -1)
        out = torch.cat([label, out], 1)
        out = self.process(out)
        validity = self.adv_layer(out)        # real/fake score
        latent_code = self.latent_layer(out)  # predicted continuous code

        return validity, latent_code
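
The downsampling in the discriminator can be checked with the standard output-size formula for a convolution. The sketch below assumes the notebook defaults (`img_size=32`, 10 classes) and a square input; `conv_out_size` is a helper introduced only for this illustration.

```python
def conv_out_size(size, kernel=3, stride=2, padding=1):
    """Output height/width of a Conv2d layer (dilation 1) for a square input."""
    return (size + 2 * padding - kernel) // stride + 1

img_size, n_classes = 32, 10  # notebook defaults

size = img_size
sizes = [size]
for _ in range(4):            # four discriminator blocks, each with a stride-2 conv
    size = conv_out_size(size)
    sizes.append(size)

flat_features = 128 * size ** 2          # flattened output of the last block
process_in = flat_features + n_classes   # after concatenating the one-hot label

print(sizes, flat_features, process_in)
```

The final spatial size agrees with `ds_size = opt.img_size // 2 ** 4`, and the sum of flattened features and label entries gives the input width of the first linear layer.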

In [None]:
# Loss functions
adversarial_loss = torch.nn.MSELoss()
continuous_loss = torch.nn.MSELoss()

# Loss weights
lambda_con = 0.1

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    continuous_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
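
Both loss objects above are mean-squared-error losses, so the adversarial terms take a least-squares form: predictions are pulled toward 1 for "valid" and toward 0 for "fake". The scalar sketch below works through the three loss formulas used in the training loop with made-up numbers; the helper `mse` and all prediction values are illustrative only.

```python
def mse(pred, target):
    """Mean squared error over two equal-length lists (mean reduction, as in nn.MSELoss by default)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

lambda_con = 0.1  # same weight as in the notebook

# Hypothetical discriminator scores for a batch of two images
real_pred, fake_pred = [0.9, 0.7], [0.2, 0.4]
valid, fake = [1.0, 1.0], [0.0, 0.0]

g_loss = mse(fake_pred, valid)                               # generator wants fakes scored as valid
d_loss = (mse(real_pred, valid) + mse(fake_pred, fake)) / 2  # average of real and fake terms

# Hypothetical sampled code vs. the code recovered by the discriminator (flattened)
code_input, pred_code = [0.5, -0.3], [0.4, -0.1]
info_loss = lambda_con * mse(pred_code, code_input)

print(g_loss, d_loss, info_loss)
```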


## Data Preparation
The MNIST dataset is loaded and transformed to fit the model requirements, such as resizing images and normalizing them. A DataLoader is then configured to automatically batch and shuffle the data for efficient training.


In [None]:

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
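
`transforms.Normalize([0.5], [0.5])` rescales pixel values from [0, 1] to [-1, 1], which matches the range of the generator's `Tanh` output. A small sketch of the per-pixel arithmetic follows; `normalize` and `denormalize` are names chosen for this example.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel operation applied by transforms.Normalize([0.5], [0.5])."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, useful for viewing generated images in [0, 1]."""
    return y * std + mean

for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```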


## Setup of Optimizers
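The generator and discriminator each get their own Adam optimizer, and a third Adam optimizer over the parameters of both networks is used for the information loss. Adam adapts the step size for each parameter from running estimates of the gradient's first and second moments. The cell also defines tensor type aliases and the fixed generator inputs used for sampling.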


In [None]:

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_info = torch.optim.Adam(
    itertools.chain(generator.parameters(), discriminator.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# Static generator inputs for sampling
static_z = Variable(FloatTensor(np.zeros((opt.n_classes ** 2, opt.latent_dim))))
static_label = to_categorical(
    np.array([num for _ in range(opt.n_classes) for num in range(opt.n_classes)]), num_columns=opt.n_classes
)
static_code = Variable(FloatTensor(np.zeros((opt.n_classes ** 2, opt.code_dim))))
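
To give a feel for what the Adam settings imply, the sketch below computes the size of the very first update for a single scalar parameter. It uses the notebook's default `lr`, `b1`, and `b2`; `eps` is assumed to be PyTorch's default of 1e-8, and `adam_first_step` is a hypothetical helper, not part of the notebook.

```python
import math

lr, b1, b2, eps = 0.0002, 0.5, 0.999, 1e-8  # eps is an assumed default

def adam_first_step(grad):
    """Size of the first Adam update (t = 1) for one scalar parameter."""
    m = (1 - b1) * grad            # first-moment estimate after one step
    v = (1 - b2) * grad ** 2       # second-moment estimate after one step
    m_hat = m / (1 - b1)           # bias correction for t = 1
    v_hat = v / (1 - b2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

for g in (0.3, -0.01, 5.0):
    print(g, "->", adam_first_step(g))
```

On the first step the update is close to `lr` in magnitude regardless of the gradient's scale, and its sign follows the gradient; later steps depend on the gradient history.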



## Function for Sampling Images
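To monitor the progress of our model, we periodically save generated images. This function saves three grids: one with fixed labels and codes but freshly sampled noise, and two that sweep one latent code each from -1 to 1 while the noise is held at zero. Comparing the grids across training shows how image quality and the effect of each code evolve.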


In [None]:

def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Static sample
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    static_sample = generator(z, static_label, static_code)
    save_image(static_sample.data, "images/static/%d.png" % batches_done, nrow=n_row, normalize=True)

    # Get varied c1 and c2
    zeros = np.zeros((n_row ** 2, 1))
    c_varied = np.repeat(np.linspace(-1, 1, n_row)[:, np.newaxis], n_row, 0)
    c1 = Variable(FloatTensor(np.concatenate((c_varied, zeros), -1)))
    c2 = Variable(FloatTensor(np.concatenate((zeros, c_varied), -1)))
    sample1 = generator(static_z, static_label, c1)
    sample2 = generator(static_z, static_label, c2)
    save_image(sample1.data, "images/varying_c1/%d.png" % batches_done, nrow=n_row, normalize=True)
    save_image(sample2.data, "images/varying_c2/%d.png" % batches_done, nrow=n_row, normalize=True)
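
The layout of the varying-code grids follows from how the labels and codes are ordered. The sketch below reproduces that ordering in plain Python for a 3 x 3 grid, under the assumption that `n_row` equals the number of classes (as in the notebook, where both are 10).

```python
n_row = 3  # small grid for readability; the notebook uses n_row=10

# Equivalent of np.linspace(-1, 1, n_row)
codes = [-1 + 2 * k / (n_row - 1) for k in range(n_row)]
# Equivalent of np.repeat(..., n_row, 0): each code value fills one whole row
c_varied = [c for c in codes for _ in range(n_row)]
# Same comprehension as static_label: class indices cycle 0..n_row-1 within each row
labels = [num for _ in range(n_row) for num in range(n_row)]

# save_image(..., nrow=n_row) places n_row consecutive samples on each row
for r in range(n_row):
    row = list(zip(labels[r * n_row:(r + 1) * n_row], c_varied[r * n_row:(r + 1) * n_row]))
    print(row)
```

Each printed row pairs every class with the same code value, so in the saved images the columns step through the digit classes and the rows step through the code from -1 to 1.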



## Training Process
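This section outlines the training loop. Each iteration performs three updates in order: the generator is trained on the adversarial loss, the discriminator is trained on the average of its losses on real and generated images, and both networks are then updated jointly on the information loss. The information loss is the mean squared error between the sampled latent code and the code predicted by the discriminator, scaled by `lambda_con`. Together these updates aim to improve both the realism of generated images and their alignment with the input conditions.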


In [None]:

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = to_categorical(labels.numpy(), num_columns=opt.n_classes)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes)
        code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))

        # Generate a batch of images
        gen_imgs = generator(z, label_input, code_input)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs,label_input)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, _  = discriminator(real_imgs,labels)
        d_real_loss = adversarial_loss(real_pred, valid)

        # Loss for fake images
        fake_pred, _ = discriminator(gen_imgs.detach(),label_input)
        d_fake_loss = adversarial_loss(fake_pred, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # ------------------
        # Information Loss
        # ------------------

        optimizer_info.zero_grad()

        # Sample labels
        sampled_labels = np.random.randint(0, opt.n_classes, batch_size)

        # Ground truth labels
        gt_labels = Variable(LongTensor(sampled_labels), requires_grad=False)

        # Sample noise, labels and code as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        label_input = to_categorical(sampled_labels, num_columns=opt.n_classes)
        code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, opt.code_dim))))

        gen_imgs = generator(z, label_input, code_input)
        _, pred_code = discriminator(gen_imgs, label_input)

        info_loss = lambda_con * continuous_loss(pred_code, code_input)

        info_loss.backward()
        optimizer_info.step()

        # --------------
        # Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [info loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), info_loss.item())
        )
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)



[Epoch 0/20] [Batch 0/938] [D loss: 0.405005] [G loss: 0.798505] [info loss: 0.032979]
[Epoch 0/20] [Batch 1/938] [D loss: 0.402233] [G loss: 0.792346] [info loss: 0.035816]
[Epoch 0/20] [Batch 2/938] [D loss: 0.400214] [G loss: 0.786637] [info loss: 0.030509]
[Epoch 0/20] [Batch 3/938] [D loss: 0.398042] [G loss: 0.781814] [info loss: 0.032879]
[Epoch 0/20] [Batch 4/938] [D loss: 0.395671] [G loss: 0.775972] [info loss: 0.037234]
[Epoch 0/20] [Batch 5/938] [D loss: 0.393604] [G loss: 0.772419] [info loss: 0.026617]
[Epoch 0/20] [Batch 6/938] [D loss: 0.391137] [G loss: 0.766846] [info loss: 0.035253]
[Epoch 0/20] [Batch 7/938] [D loss: 0.388643] [G loss: 0.758878] [info loss: 0.032524]
[Epoch 0/20] [Batch 8/938] [D loss: 0.385617] [G loss: 0.755485] [info loss: 0.036618]
[Epoch 0/20] [Batch 9/938] [D loss: 0.384629] [G loss: 0.750717] [info loss: 0.032853]
[Epoch 0/20] [Batch 10/938] [D loss: 0.380965] [G loss: 0.743705] [info loss: 0.027528]
[Epoch 0/20] [Batch 11/938] [D loss: 0.378
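
The batch count in the log and the number of saved image grids follow from simple arithmetic. The sketch below assumes the MNIST training split of 60,000 images and the option values parsed earlier in this notebook.

```python
import math

n_train, batch_size = 60000, 64      # MNIST training set size; batch size used in this run
n_epochs, sample_interval = 20, 400  # values from the parsed options

batches_per_epoch = math.ceil(n_train / batch_size)   # the last batch is smaller
total_batches = n_epochs * batches_per_epoch
last_batch_size = n_train - (batches_per_epoch - 1) * batch_size

# sample_image runs whenever batches_done % sample_interval == 0, starting at 0
n_snapshots = len(range(0, total_batches, sample_interval))

print(batches_per_epoch, total_batches, last_batch_size, n_snapshots)
```

The smaller final batch is why the training loop reads `batch_size` from `imgs.shape[0]` on every iteration instead of relying on `opt.batch_size`.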

## Conclusion and Observations
At the end of training, the image grids saved under `images/` can be examined to evaluate the model's performance: how diverse and realistic the generated digits are, how stable the results stay across epochs, and whether sweeping each latent code produces a consistent change in the output. The inclusion of mutual information is hypothesized to enhance both the diversity and fidelity of the generated samples.
