In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import torch.nn.functional as F
torch.manual_seed(0) 

<torch._C.Generator at 0x7fd22ed13a50>

In [2]:

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    image_tensor = (image_tensor + 1) / 2# rescales to from the range [-1, 1] to [0, 1]
    image_unflat = image_tensor.detach().cpu() # detached from the computation graph
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze()) # squeeze to remove any singleton dim
    if show:
        plt.show()

In [3]:
class Generator(nn.Module):
    def __init__(self, input_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                # increasing the spatial resolution
                # deconvolution, upsampling
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride), # for generating images with higgher dimension
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True), # used only in non-final layers, helps the network lean complex patterns in the data. introduce non-linearity
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(), # range [-1, 1] - matches the typical range of real image pixel values normalized to this range.
            )

    def forward(self, noise): # len(noise) - batch size
        x = noise.view(len(noise), self.input_dim, 1, 1) # (1, 1) - row and col
        return self.gen(x) 

def get_noise(n_samples, input_dim, device='cpu'):
    return torch.randn(n_samples, input_dim, device=device)

In [4]:
class Discriminator(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True), # to learn both positive and negtive values in the feature map
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride), # it is binray classification, only one output value, so no activations.
            )

    def forward(self, image):
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

In [5]:
# getting one-hot vector for the labels - conditional information
def get_one_hot_labels(labels, n_classes):
    return F.one_hot(labels, num_classes=n_classes) # returns (?, num_classes)

In [6]:
# combining two vectors, noise vecotr and the on-hot vector

def combine_vectors(x, y):
    combined = torch.cat(tensors=[x, y], dim=1).float()
    return combined

In [7]:
mnist_shape = (1, 28, 28)
n_classes = 10

criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.0002
device = 'cuda'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('../W1 Intro to GAN/', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [8]:
mnist_shape

(1, 28, 28)

In [9]:
# for getting the size of the conditional input dimension
def get_input_dimensions(z_dim, mnist_shape, n_classes):
    generator_input_dim = z_dim + n_classes

    # Calculate the number of input channels for the discriminator
    discriminator_im_chan = mnist_shape[0] + n_classes # channels in MNIST images with n_classes

    return generator_input_dim, discriminator_im_chan

In [10]:
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, mnist_shape, n_classes)

gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): # for convolution layers
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d): # for batchnorm layers
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
        
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [11]:
'''
In cGANs, control over the gen imgs is done by providing conditional information to both the generator and discriminator/ 

Gen takes both random noise and conditionla information as input.
Disc takes both real images and generated images along with the corresponding conditional information
'''

'\nIn cGANs, control over the gen imgs is done by providing conditional information to both the generator and discriminator/ \n\nGen takes both random noise and conditionla information as input.\nDisc takes both real images and generated images along with the corresponding conditional information\n'

In [12]:
def show_real_and_fake_images(real_images, fake_images, num_images=5, size=(1, 28, 28), nrow=5):
    """
    Display real and fake images side by side.

    Args:
        real_images (torch.Tensor): A tensor containing real images.
        fake_images (torch.Tensor): A tensor containing fake/generated images.
        num_images (int): The number of images to display for each type (real and fake).
        size (tuple): The size of the images (e.g., (1, 28, 28) for grayscale MNIST).
        nrow (int): Number of images to display in each row.

    """
    real_images = (real_images + 1) / 2  # Rescale from [-1, 1] to [0, 1]
    fake_images = (fake_images + 1) / 2

    # Detach from the computation graph and move to CPU for display
    real_images_unflat = real_images.detach().cpu()
    fake_images_unflat = fake_images.detach().cpu()

    # Create image grids for real and fake images
    real_image_grid = make_grid(real_images_unflat[:num_images], nrow=nrow)
    fake_image_grid = make_grid(fake_images_unflat[:num_images], nrow=nrow)

    # Plot real and fake images side by side
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(real_image_grid.permute(1, 2, 0).squeeze(), cmap='gray')
    axes[0].set_title("Real Images")
    axes[0].axis('off')
    axes[1].imshow(fake_image_grid.permute(1, 2, 0).squeeze(), cmap='gray')
    axes[1].set_title("Fake Images")
    axes[1].axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
cur_step = 0

generator_losses = []
discriminator_losses = []

noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False

disc_fake_pred = False
disc_real_pred = False

for epoch in range(n_epochs):
    for real, labels in tqdm(dataloader):
        cur_batch_size = len(real) # no of images in the curr batch
        
        real = real.to(device)
        
        one_hot_labels = get_one_hot_labels(labels.to(device), n_classes) # move tensor to the device
        # (?, num_classes)
        image_one_hot_labels = one_hot_labels[:, :, None, None] 
        image_one_hot_labels = image_one_hot_labels.repeat(1, 1, mnist_shape[1],mnist_shape[2]) # reshapes one-hot labels to match the shape of the images in the dataset. it repeats the on-hot labels across the imaeg dimensions
        
        ### Update the discriminator
        disc_opt.zero_grad() # initilize the grads of disc
        fake_noise = get_noise(cur_batch_size, z_dim, device) #gens random noise of batch size
        
        # Combine the noise vectors and the one-hot labels for the generator
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        
        fake = gen(noise_and_labels) # generates fake images
        # Generate the conditional Fake images

        '''
        - Create the input for the disc
        - get the disc pred on the fakes as disc_fake_pred
        - get the disc pred on the reals as disc_fake_pred
        '''
        # prepare input for the disc
        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels) # combine fake with labels
        real_image_and_labels = combine_vectors(real, image_one_hot_labels) # combine real with labels
        
        '''Training with real and fake images'''
        # disc predictions fr the fake and real images
        disc_fake_pred = disc(fake_image_and_labels.detach())
        disc_real_pred = disc(real_image_and_labels)
        
        
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        
        disc_loss = (disc_fake_loss + disc_real_loss) / 2 # average loss of these two losses
        
        disc_loss.backward(retain_graph = True) # update the disc weights
        disc_opt.step()
        
        discriminator_losses += [disc_loss.item()]
        
        
        ### Update the Generator
        gen_opt.zero_grad()
        

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels) # generated image and it's real lables
        disc_fake_pred = disc(fake_image_and_labels) # value of the prediction of disc
        
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        
        gen_loss.backward()
        gen_opt.step()
        
        generator_losses += [gen_loss.item()]
        
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            
            print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")

            show_tensor_images(fake)
            show_tensor_images(real)
            
            step_bins = 20
            
            x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
            num_examples = (len(generator_losses) // step_bins) * step_bins
            
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label = "Generated Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label = "Discriminator Loss"
            )
            plt.legend()
            plt.show()
        elif cur_step == 0:
            print("Training Started......")
        
        cur_step += 1
            

In [38]:
gen = Generator()
gen

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(10, 256, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2))
      (1): Tanh()
    )
  )
)

In [39]:
disc = Discriminator()
disc

Discriminator(
  (disc): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 1, kernel_size=(4, 4), stride=(2, 2))
    )
  )
)

In [40]:
gen = gen.eval()

In [41]:
gen

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(10, 256, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2))
      (1): Tanh()
    )
  )
)

In [42]:
torch.__version__

'2.0.1+cu117'

In [None]:
import math

n_interpolation = 9 # no of intermediate images we want + 2 (for the start and end image)
interpolation_noise = get_noise(1, z_dim, device=device).repeat(n_interpolation, 1)

def interpolate_class(first_number, second_number):
    first_label = get_one_hot_labels(torch.Tensor([first_number]).long(), n_classes)
    second_label = get_one_hot_labels(torch.Tensor([second_number]).long(), n_classes)

    # Calculate the interpolation vector between the two labels
    percent_second_label = torch.linspace(0, 1, n_interpolation)[:, None]
    interpolation_labels = first_label * (1 - percent_second_label) + second_label * percent_second_label

    # Combine the noise and the labels
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_labels.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)

start_plot_number = 1 #  start digit

end_plot_number = 5 #   end digit

plt.figure(figsize=(8, 8))
interpolate_class(start_plot_number, end_plot_number)
_ = plt.axis('off')


plot_numbers = [2, 3, 4, 5, 7]
n_numbers = len(plot_numbers)
plt.figure(figsize=(8, 8))
for i, first_plot_number in enumerate(plot_numbers):
    for j, second_plot_number in enumerate(plot_numbers):
        plt.subplot(n_numbers, n_numbers, i * n_numbers + j + 1)
        interpolate_class(first_plot_number, second_plot_number)
        plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()

In [None]:
n_interpolation = 9 # How many intermediate images you want + 2 (for the start and end image)

# This time you're interpolating between the noise instead of the labels
interpolation_label = get_one_hot_labels(torch.Tensor([5]).long(), n_classes).repeat(n_interpolation, 1).float()

def interpolate_noise(first_noise, second_noise):
    # This time you're interpolating between the noise instead of the labels
    percent_first_noise = torch.linspace(0, 1, n_interpolation)[:, None].to(device)
    interpolation_noise = first_noise * percent_first_noise + second_noise * (1 - percent_first_noise)

    # Combine the noise and the labels again
    noise_and_labels = combine_vectors(interpolation_noise, interpolation_label.to(device))
    fake = gen(noise_and_labels)
    show_tensor_images(fake, num_images=n_interpolation, nrow=int(math.sqrt(n_interpolation)), show=False)

# Generate noise vectors to interpolate between
### Change me! ###
n_noise = 5 # Choose the number of noise examples in the grid
plot_noises = [get_noise(1, z_dim, device=device) for i in range(n_noise)]
plt.figure(figsize=(8, 8))
for i, first_plot_noise in enumerate(plot_noises):
    for j, second_plot_noise in enumerate(plot_noises):
        plt.subplot(n_noise, n_noise, i * n_noise + j + 1)
        interpolate_noise(first_plot_noise, second_plot_noise)
        plt.axis('off')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0.1, wspace=0)
plt.show()
plt.close()