# Setup

In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
from tqdm import tqdm
from torch.utils.data.sampler import SequentialSampler

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML
%matplotlib inline

In [2]:
# Training data: MNIST or Fashion MNIST
DATASET_NAME = "MNIST" #@param ["MNIST", "Fashion MNIST"]
# Which models to train: "VAE", "cDCGAN", "VAE and cDCGAN" or "HAe/DaES"
# (note: the variable name is spelled ARCHITECHTURE throughout the notebook)
ARCHITECHTURE = "HAe/DaES" #@param ["VAE", "cDCGAN", "VAE and cDCGAN", "HAe/DaES"]
# batch size for training models
BATCH_SIZE = 128 #@param {type:"integer"}
# batch size for testing models
TEST_BATCH_SIZE = 128 #@param {type:"integer"}
# number of latent samples drawn from the VAE posterior for every image
NUM_SAMPLES = 10 #@param {type:"integer"}
# factor multiplying the posterior standard deviation when drawing those samples
STD_SCALING = 7.5 #@param {type:"number"}
# number of training epochs when ARCHITECHTURE is "VAE", "cDCGAN" or "VAE and cDCGAN"
NUM_EPOCHS = 200 #@param {type:"integer"}
# run validation every this many epochs (must be a positive integer)
EPOCHS_BETWEEN_VAL = 100 #@param {type:"integer"}
# Whether or not to shuffle the training and testing data
# note: mutually exclusive with TRAIN_ON_SUBSET (the subset is drawn through a
# DataLoader sampler, and DataLoader rejects shuffle=True together with a sampler)
SHUFFLE = False #@param {type:"boolean"}
# Whether or not to train on a random subset of the training data
# note: mutually exclusive with SHUFFLE
TRAIN_ON_SUBSET = True #@param {type:"boolean"}
# Size of the subset on which to train
SUBSET_SIZE = 600 #@param {type:"integer"}
# Whether or not to pre-train the VAE on its own before the models are combined
PRE_TRAIN = True #@param {type:"boolean"}
# number of VAE pre-training epochs, used when ARCHITECHTURE is none of
# "VAE", "cDCGAN" or "VAE and cDCGAN"
NUM_PRETRAIN_EPOCHS = 200 #@param {type:"integer"}

# Fail fast on inconsistent settings instead of erroring deep inside a later cell.
if SHUFFLE and TRAIN_ON_SUBSET:
    raise ValueError("SHUFFLE and TRAIN_ON_SUBSET are mutually exclusive: "
                     "DataLoader does not allow shuffle=True together with a sampler")
if EPOCHS_BETWEEN_VAL <= 0:
    raise ValueError("EPOCHS_BETWEEN_VAL must be a positive integer")

# Set the device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {DEVICE} backend")

Using cuda backend


# Loading Data

In [3]:
# Data Preprocessing

# ToTensor turns each PIL image into a float tensor scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

if DATASET_NAME == "MNIST":
    # download and load the training and test sets of MNIST
    train_data = datasets.MNIST("./mnist", train=True, download=True, transform=transform)
    test_data = datasets.MNIST("./mnist", train=False, download=True, transform=transform)

elif DATASET_NAME == "Fashion MNIST":
    # download and load the training and test sets of Fashion MNIST
    train_data = datasets.FashionMNIST("./fmnist", train=True, download=True, transform=transform)
    test_data = datasets.FashionMNIST("./fmnist", train=False, download=True, transform=transform)
else:
    # Previously the exception was constructed but never raised, so an
    # unsupported name fell through to a NameError on `train_data` below.
    raise ValueError(f"Please select valid dataset, {DATASET_NAME} is not supported")

if TRAIN_ON_SUBSET:
    # Random permutation of all training indices; keep the first SUBSET_SIZE.
    # NOTE: randperm is unseeded here, so the chosen subset differs between runs.
    indices = torch.randperm(len(train_data))
    selected_indices = indices[:SUBSET_SIZE]

    # Bug fix: SequentialSampler(selected_indices) yields
    # range(len(selected_indices)), i.e. always the FIRST SUBSET_SIZE images,
    # ignoring the random selection above. DataLoader accepts any iterable of
    # indices as a sampler, so pass the chosen indices directly; they are
    # visited in the same order every epoch.
    sampler = selected_indices.tolist()
else:
    sampler = None

# drop_last=True keeps only full batches (with SUBSET_SIZE=600 and
# BATCH_SIZE=128 that is 4 batches = 512 images per epoch).
# shuffle must stay False whenever a sampler is given.
dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=2, drop_last=True, sampler=sampler)

testloader = torch.utils.data.DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=SHUFFLE, num_workers=2, drop_last=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 165054373.69it/s]

Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 15576146.82it/s]


Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 71770647.84it/s]


Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18074505.47it/s]

Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw






# Variational Autoencoder Model

In [4]:
# Dimensionality of the VAE latent space; the cDCGAN is conditioned on this
# latent code, so it is also used as LABEL_SIZE for D and G.
ENCODED_DIM = 16 #@param {type:"integer"}

In [5]:
# https://medium.com/dataseries/convolutional-autoencoder-in-pytorch-on-mnist-dataset-d65145c132ac
class Autoencoder(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    forward(x) returns (reconstruction, z, mu, logvar), where z is a
    reparameterized sample from N(mu, exp(logvar)) and reconstruction is the
    decoder's Sigmoid output for z.
    """

    def __init__(self, encoded_space_dim):
        super(Autoencoder, self).__init__()

        # Encoder: three strided convolutions (28x28 -> 14x14 -> 7x7 -> 3x3),
        # then an MLP down to a 256-d feature vector.
        encoder_layers = [
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True),
            nn.Flatten(start_dim=1),
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(True),
            nn.Linear(128, 256),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Two linear heads turn the shared features into posterior parameters.
        self.fc_mu = nn.Linear(256, encoded_space_dim)
        self.fc_logvar = nn.Linear(256, encoded_space_dim)

        # Decoder: mirror of the encoder (3x3 -> 7x7 -> 14x14 -> 28x28), with a
        # Sigmoid so pixels land in [0, 1].
        decoder_layers = [
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True),
            nn.Unflatten(dim=1, unflattened_size=(32, 3, 3)),
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Draw a differentiable sample z = mu + eps * sigma, eps ~ N(0, I).

        :param mu: mean from the encoder's latent space
        :param logvar: log variance from the encoder's latent space
        :return: a sample with the same shape as ``mu``
        """
        sigma = (0.5 * logvar).exp()
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it."""
        features = self.encoder(x)
        mu, logvar = self.fc_mu(features), self.fc_logvar(features)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), z, mu, logvar

In [6]:
AE = Autoencoder(ENCODED_DIM).to(DEVICE)
print(AE)
# Count only the parameters that will be optimized.
AE_total_params = sum(param.numel() for param in filter(lambda p: p.requires_grad, AE.parameters()))
print(f"Total Trainable Parameters: {AE_total_params}")

KeyboardInterrupt: ignored

# Discriminator Model

In [None]:
# https://github.com/drc10723/GAN_design/blob/master/GAN_implementations/Conditional_DCGAN_MNIST.ipynb

# The conditional GAN is conditioned on the VAE latent code (G and D receive
# encoded images), so the "label" size equals the latent dimension.
LABEL_SIZE = ENCODED_DIM


class Discriminator(nn.Module):
  """ D(x) -- conditional DCGAN discriminator.

  Scores a (batch, 1, 28, 28) image together with a (batch, label_size, 28, 28)
  conditioning tensor and returns one probability per sample, shape (batch, 1).
  """
  def __init__(self, label_size):
    super().__init__()

    # Image branch: (B, 1, 28, 28) -> (B, 32, 14, 14)
    self.layer_x = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Condition branch: (B, label_size, 28, 28) -> (B, 32, 14, 14)
    self.layer_y = nn.Sequential(
        nn.Conv2d(in_channels=label_size, out_channels=32, kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Joint trunk on the concatenated branches:
    # (B, 64, 14, 14) -> (B, 128, 7, 7) -> (B, 256, 3, 3) -> (B, 1, 1, 1).
    # The last 3x3 conv with a single output channel replaces a Linear head
    # (as recommended by the DCGAN paper); Sigmoid maps the score to [0, 1].
    self.layer_xy = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=0, bias=False),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(in_channels=256, out_channels=1, kernel_size=3, stride=1, padding=0, bias=False),
        nn.Sigmoid(),
    )

  def forward(self, x, y):
    """Return D(x | y) flattened to shape (batch_size, 1)."""
    joint = torch.cat([self.layer_x(x), self.layer_y(y)], dim=1)
    score = self.layer_xy(joint)
    return score.view(score.shape[0], -1)

In [None]:
# Create the Discriminator, conditioned on LABEL_SIZE channels.
D = Discriminator(LABEL_SIZE).to(DEVICE)
print(D)
D_total_params = sum(param.numel() for param in filter(lambda p: p.requires_grad, D.parameters()))
print(f"Total Trainable Parameters: {D_total_params}")

# Generator Model

In [None]:
class Generator(nn.Module):
  """ G(z) -- conditional DCGAN generator.

  Maps a noise vector z of size ``input_size`` and a conditioning vector y of
  size ``label_size`` to a (batch, 1, 28, 28) image in [0, 1].

  Args:
    label_size: number of entries of the conditioning vector y.
    input_size: number of entries of the noise vector z (default 100).
      Previously this argument was ignored and the first layer always expected
      100 channels; it is now honoured.
  """
  def __init__(self, label_size, input_size=100):
    # initalize super module
    super(Generator, self).__init__()

    # noise z input layer : (batch_size, input_size, 1, 1) -> (batch_size, 128, 3, 3)
    self.layer_x = nn.Sequential(nn.ConvTranspose2d(in_channels=input_size, out_channels=128, kernel_size=3,
                                                    stride=1, padding=0, bias=False),
                                 nn.BatchNorm2d(128),
                                 nn.ReLU(),
                                )

    # label input layer : (batch_size, label_size, 1, 1) -> (batch_size, 128, 3, 3)
    self.layer_y = nn.Sequential(nn.ConvTranspose2d(in_channels=label_size, out_channels=128, kernel_size=3,
                                                    stride=1, padding=0, bias=False),
                                 nn.BatchNorm2d(128),
                                 nn.ReLU(),
                                )

    # concatenated input : (batch_size, 256, 3, 3)
    #   -> (batch_size, 128, 7, 7) -> (batch_size, 64, 14, 14) -> (batch_size, 1, 28, 28)
    # Sigmoid keeps pixels in [0, 1], matching the ToTensor-scaled real images.
    self.layer_xy = nn.Sequential(nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3,
                                                     stride=2, padding=0, bias=False),
                                  nn.BatchNorm2d(128),
                                  nn.ReLU(),
                                  nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4,
                                                     stride=2, padding=1, bias=False),
                                  nn.BatchNorm2d(64),
                                  nn.ReLU(),
                                  nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4,
                                                     stride=2, padding=1, bias=False),
                                  nn.Sigmoid()
                                 )

  def forward(self, x, y):
    """Generate images from noise ``x`` (batch, input_size) and condition ``y`` (batch, label_size)."""
    # x : (batch_size, input_size) -> (batch_size, input_size, 1, 1)
    x = x.view(x.shape[0], x.shape[1], 1, 1)
    x = self.layer_x(x)
    # x : (batch_size, 128, 3, 3)

    # y : (batch_size, label_size) -> (batch_size, label_size, 1, 1)
    y = y.view(y.shape[0], y.shape[1], 1, 1)
    y = self.layer_y(y)
    # y : (batch_size, 128, 3, 3)

    # concat along channels : (batch_size, 256, 3, 3)
    xy = torch.cat([x, y], dim=1)
    xy = self.layer_xy(xy)
    # xy : (batch_size, 1, 28, 28)
    return xy

In [None]:
# Create the Generator, conditioned on LABEL_SIZE channels.
G = Generator(LABEL_SIZE).to(DEVICE)
print(G)
G_total_params = sum(param.numel() for param in filter(lambda p: p.requires_grad, G.parameters()))
print(f"Total Trainable Parameters: {G_total_params}")

# Weight Initialization

In [None]:
# custom weights initialization
def weights_init(net):
    """DCGAN-style initializer, meant to be passed to ``Module.apply``.

    Modules whose class name contains "Conv" get weights ~ N(0, 0.02).
    Modules whose class name contains "BatchNorm" get weights ~ N(1, 0.02)
    and zero bias. Every other module is left untouched.
    """
    layer_type = type(net).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(net.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(net.weight.data, 1.0, 0.02)
        nn.init.constant_(net.bias.data, 0)

In [None]:
# Re-initialize D and G DCGAN-style: conv weights ~ N(0, 0.02),
# BatchNorm weights ~ N(1, 0.02) with zero bias (see weights_init).
D.apply(weights_init)
# Trailing ';' suppresses the echo of the long module repr that Module.apply
# returns when it is the cell's last expression.
G.apply(weights_init);

# Model Training

## Setup



*   The value of the beta1 hyperparameter in the Adam optimizer has a large impact on generator stability; the DCGAN paper recommends 0.5.
*   The DCGAN paper recommends a learning rate of 0.0002 for Adam.

In [None]:
# Size of the noise vector z given to the generator. It must equal the
# Generator's `input_size` (default 100, which is what G is built with).
size_z = 100
# number of discriminator steps for each generator step
Ksteps = 1 #@param {type:"integer"}
# number of discriminator steps for each Autoencoder step
Jsteps = 1 #@param {type:"integer"}
# Learning rate shared by the D and G Adam optimizers
# (the DCGAN paper recommends 0.0002)
Adam_lr = 0.0002 #@param {type:"number"}
# Adam beta1 for D and G (the DCGAN paper recommends 0.5)
Adam_beta1 = 0.5 #@param {type:"number"}
# NOTE(review): a comment here previously announced a "scaling factor applied to
# Cross Entropy classification loss", but no such parameter is defined.

In [None]:
# Loss functions, all with PyTorch's default reduction='mean'.
discrimination_loss = nn.BCELoss()      # real/fake loss on the discriminator's sigmoid output
reconstruction_loss = nn.MSELoss()      # pixel-wise VAE reconstruction error
classifier_loss = nn.CrossEntropyLoss()

# Optimizers: Adam with the DCGAN settings for the GAN pair, Adam(lr=1e-3) for the VAE.
D_optimizer = torch.optim.Adam(D.parameters(), lr=Adam_lr, betas=(Adam_beta1, 0.999))  # discriminator
G_optimizer = torch.optim.Adam(G.parameters(), lr=Adam_lr, betas=(Adam_beta1, 0.999))  # generator
AE_optimizer = torch.optim.Adam(AE.parameters(), lr=0.001)

def kl_divergence(mu, logvar, beta=0.001):
  """Weighted KL penalty pushing N(mu, exp(logvar)) towards N(0, I).

  Computes ``-beta * mean(1 + logvar - mu**2 - exp(logvar))``. The mean is
  taken over every element (batch and latent dimensions), and ``beta`` absorbs
  both the usual 0.5 factor of the closed-form KL and the weight of this term
  relative to the reconstruction loss.

  Args:
    mu: posterior means.
    logvar: posterior log-variances, same shape as ``mu``.
    beta: weight of the penalty (default 0.001, the previously hard-coded value).

  Returns:
    A scalar tensor.
  """
  return -beta * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())


In [None]:
# Discriminator targets, shaped like D's (batch_size, 1) output:
# all ones for real training images, all zeros for generated images G(z).
labels_real = torch.ones(BATCH_SIZE, 1, device=DEVICE)
labels_fake = torch.zeros(BATCH_SIZE, 1, device=DEVICE)
# Fixed noise (drawn once on the CPU RNG) for testing the generator and visualization.
z_test = torch.randn(100, size_z).to(DEVICE)

In [None]:
img_size = 28

# One-hot rows for the 10 classes: row k has a 1 in column k (the 10x10 identity).
onehot = torch.eye(10).to(DEVICE)
# Per-class label "images": fill[k] has shape (LABEL_SIZE, img_size, img_size)
# with channel k set to one and every other channel zero (requires LABEL_SIZE >= 10).
fill = torch.zeros([10, LABEL_SIZE, img_size, img_size]).to(DEVICE)
class_ids = torch.arange(10, device=fill.device)
fill[class_ids, class_ids] = 1
# Test labels 0..9 repeated 10 times (100 entries) and their one-hot encodings.
# NOTE(review): test_Gy has 10 columns while G is built with LABEL_SIZE
# conditioning channels; looks like a leftover from a class-conditional cDCGAN.
test_y = torch.arange(10).repeat(10)
test_Gy = onehot[test_y].to(DEVICE)

## Save Checkpoint

In [None]:
AE_path0 = "./AE_checkpoint0"
G_path0 = "./G_checkpoint0"
D_path0 = "./D_checkpoint0"

# Snapshot the current AE, G and D (weights + optimizer state) so the training
# sections below can restart from this point.
for ckpt_path, network, network_optimizer in ((AE_path0, AE, AE_optimizer),
                                              (G_path0, G, G_optimizer),
                                              (D_path0, D, D_optimizer)):
    torch.save({
        'model_state_dict': network.state_dict(),
        'optimizer_state_dict': network_optimizer.state_dict(),
    }, ckpt_path)

## Train the VAE

### Load checkpoint

In [None]:
AE_checkpoint = torch.load(AE_path0)
# Restore the VAE weights first, then its optimizer state.
for state_key, component in (('model_state_dict', AE),
                             ('optimizer_state_dict', AE_optimizer)):
  component.load_state_dict(AE_checkpoint[state_key])

In [None]:
G_checkpoint = torch.load(G_path0)
# Restore the generator weights first, then its optimizer state.
for state_key, component in (('model_state_dict', G),
                             ('optimizer_state_dict', G_optimizer)):
  component.load_state_dict(G_checkpoint[state_key])

In [None]:
D_checkpoint = torch.load(D_path0)
# Restore the discriminator weights first, then its optimizer state.
for state_key, component in (('model_state_dict', D),
                             ('optimizer_state_dict', D_optimizer)):
  component.load_state_dict(D_checkpoint[state_key])

### Begin Training

In [None]:
def get_class_examples(AE_dataloader):
  """Collect the first image seen for each of the 10 classes.

  Args:
    AE_dataloader: iterable of (images, classes) batches; each image must fit
      into a (1, 28, 28) slot and each class is an integer label.

  Returns:
    A (10, 1, 28, 28) CPU tensor whose row k is the first image labelled k.
    Rows of classes that never appear stay all-zero. Iteration stops as soon
    as all 10 classes have been found.
  """
  found = [False] * 10
  examples = torch.zeros(10, 1, 28, 28)
  for batch_images, batch_classes in AE_dataloader:
    for idx in range(batch_images.shape[0]):
      label = batch_classes[idx].item()
      if found[label]:
        continue
      examples[label] = batch_images[idx]
      found[label] = True
      if all(found):
        return examples
  return examples

# First example of every class from the training loader and from the test loader.
class_examples, val_class_examples = (
    get_class_examples(loader).to(DEVICE) for loader in (dataloader, testloader)
)

In [None]:
def display_class_examples(tensors):
  """Plot a row of single-channel images, one subplot per image.

  Args:
    tensors: iterable of image tensors, each shaped (1, H, W); a (N, 1, H, W)
      batch works because iterating it yields N such tensors. Tensors may be
      on the GPU or require grad; they are detached and copied to the CPU.

  At least 10 columns are drawn (the previous fixed layout). When more than
  10 images are given, all of them are shown instead of silently dropping
  everything past the tenth.
  """
  # Drop the leading channel axis so imshow receives an (H, W) array.
  images = [np.squeeze(tensor.detach().cpu().numpy(), axis=0) for tensor in tensors]

  ncols = max(10, len(images))
  fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(20, 2.5))
  axes = axes.flatten()

  # Hide every frame (including unused slots), then draw the images.
  for ax in axes:
    ax.axis('off')
  for image, ax in zip(images, axes):
    ax.imshow(image, cmap='gray')

  plt.show()

In [None]:
display_class_examples(class_examples)

In [None]:
display_class_examples(val_class_examples)

In [None]:
def AE_epoch(dataloader, AE, AE_optimizer, reconstruction_loss, train=True):
  """Run the VAE over every batch of ``dataloader`` once.

  Args:
    dataloader: iterable of (images, classes) batches; the class labels are unused.
    AE: the VAE; ``AE(images)`` must return (reconstruction, sample, mu, logvar).
    AE_optimizer: optimizer stepped after every batch when ``train`` is True;
      unused otherwise.
    reconstruction_loss: callable ``(reconstruction, images) -> scalar loss``.
    train: if True, back-propagate and update ``AE``; if False, only evaluate.
      This function does not switch ``AE`` between train/eval mode; the
      caller is responsible for that.

  Returns:
    (mean total loss, mean reconstruction loss), each averaged over batches,
    where total = reconstruction + kl_divergence(mu, logvar).

  Raises:
    ValueError: if ``dataloader`` yields no batches (previously a ZeroDivisionError).
  """
  epoch_AE_losses = []
  epoch_AE_reconstruction_losses = []
  for images, _ in dataloader:
    images = images.to(DEVICE)
    # Record the autograd graph only when backward() will be called, so
    # evaluation passes stay cheap even without an outer torch.no_grad().
    with torch.set_grad_enabled(train):
      reconstruction, _, mu, logvar = AE(images)
      AE_reconstruction_loss = reconstruction_loss(reconstruction, images)
      AE_loss = AE_reconstruction_loss + kl_divergence(mu, logvar)

    # save values for plots
    epoch_AE_losses.append(AE_loss.item())
    epoch_AE_reconstruction_losses.append(AE_reconstruction_loss.item())

    if train:
      AE_optimizer.zero_grad()
      AE_loss.backward()
      AE_optimizer.step()

  if not epoch_AE_losses:
    raise ValueError("AE_epoch: dataloader yielded no batches")
  num_batches = len(epoch_AE_losses)
  return (sum(epoch_AE_losses) / num_batches,
          sum(epoch_AE_reconstruction_losses) / num_batches)

In [None]:
# Number of VAE epochs: NUM_EPOCHS when ARCHITECHTURE is one of the standalone
# options ('VAE', 'cDCGAN', 'VAE and cDCGAN'), NUM_PRETRAIN_EPOCHS otherwise
# (e.g. 'HAe/DaES', where this loop only pre-trains the VAE).
if ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'cDCGAN' or \
   ARCHITECHTURE == 'VAE and cDCGAN':
  EPOCHS = NUM_EPOCHS
else:
  EPOCHS = NUM_PRETRAIN_EPOCHS

# Train the VAE when pre-training is requested or the VAE is part of the
# selected standalone architecture.
if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Training averages, one entry per epoch.
  AE_losses = []
  reconstruction_losses = []

  # Validation averages, one entry per validation run.
  val_AE_losses = []
  val_reconstruction_losses = []

  # counter always equals the 0-based epoch index. Validation runs when it is
  # a multiple of EPOCHS_BETWEEN_VAL, i.e. after the first epoch and then every
  # EPOCHS_BETWEEN_VAL epochs; plot_losses relies on this spacing.
  counter = 0
  for epoch in range(EPOCHS):
    ############################
    # Train
    ############################
    losses = AE_epoch(dataloader, AE, AE_optimizer,
                      reconstruction_loss, train=True)

    AE_losses.append(losses[0])
    reconstruction_losses.append(losses[1])

    ############################
    # Display
    ############################
    # Eval mode for display/validation; AE.train() below switches back.
    # Every epoch draws two figures (class examples and their reconstructions),
    # plus two more on validation epochs, so the output grows with EPOCHS.
    AE.eval()
    print('epoch [{}/{}], loss:{:.3f}, reconstruction:{:.3f}'.format(epoch+1, EPOCHS,
                                                          AE_losses[-1],
                                                          reconstruction_losses[-1]))
    with torch.no_grad():
      reconstructed_class_examples, _, _, _ = AE(class_examples)
      display_class_examples(class_examples)
      display_class_examples(reconstructed_class_examples)

      if counter % EPOCHS_BETWEEN_VAL == 0:
        ############################
        # Validate
        ############################
        # Runs inside torch.no_grad(), and train=False skips the optimizer step.
        val_losses = AE_epoch(testloader, AE, AE_optimizer,
                              reconstruction_loss, train=False)

        val_AE_losses.append(val_losses[0])
        val_reconstruction_losses.append(val_losses[1])

        ############################
        # Display
        ############################
        print('Validation loss:{:.3f}, reconstruction:{:.3f}'.format(val_AE_losses[-1],
                                                          val_reconstruction_losses[-1]))
        reconstructed_val_class_examples, _, _, _ = AE(val_class_examples)
        display_class_examples(val_class_examples)
        display_class_examples(reconstructed_val_class_examples)

    # Back to train mode for the next epoch (the AE is left in train mode
    # after the loop finishes).
    AE.train()
    counter += 1

### Visualizing VAE Results

In [None]:
def plot_losses(training_losses, validation_losses, offset_factor):
  """Plot per-epoch training loss and the sparser validation loss on one axis.

  Args:
    training_losses: one value per epoch; entry e is drawn at x = e.
    validation_losses: one value per validation run; run i is drawn at
      x = i * offset_factor (validation runs every offset_factor epochs,
      starting with the first epoch).
    offset_factor: number of epochs between validation runs.
  """
  val_epochs = [i * offset_factor for i in range(len(validation_losses))]
  plt.plot(val_epochs, validation_losses, label='Validation Loss')
  plt.plot(range(len(training_losses)), training_losses,
           label='Training Loss')
  # Each training entry is an epoch average, so the x axis counts epochs
  # (it was previously mislabelled 'Iteration').
  plt.xlabel('Epoch')
  plt.ylabel('Loss')
  plt.title('VAE training and validation loss')
  plt.legend()
  plt.show()


if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  plot_losses(AE_losses, val_AE_losses, EPOCHS_BETWEEN_VAL)

In [None]:
def sort_data(data, labels, num_classes=10):
  """Group the rows of ``data`` by their label.

  Args:
    data: tensor whose first dimension indexes samples (e.g. (N, D) latent codes).
    labels: 1-D integer tensor of length N with values in [0, num_classes).
    num_classes: number of groups to return (default 10, as before).

  Returns:
    A list of ``num_classes`` tensors. Entry i holds, in their original order,
    the rows of ``data`` labelled i (a (0, ...) tensor when there are none).
    Results keep ``data``'s device and dtype.

  Raises:
    ValueError: if any label lies outside [0, num_classes).
  """
  # The previous version hard-coded .cuda() (so it crashed on CPU-only runs)
  # and grew each group with torch.cat one row at a time (quadratic in N).
  # A boolean mask per class does the same grouping in one vectorised pass.
  if labels.numel() and (int(labels.min()) < 0 or int(labels.max()) >= num_classes):
    raise ValueError(f"labels must lie in [0, {num_classes})")
  return [data[labels == c] for c in range(num_classes)]

if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Average the VAE's latent samples per class over the training loader.
  # Run in eval mode: in train mode every forward pass (even under no_grad)
  # updates the BatchNorm running statistics, which silently changed the model
  # and the checkpoint saved afterwards. The previous mode is restored below.
  was_training = AE.training
  AE.eval()

  encoded_class_sums = [torch.zeros(ENCODED_DIM, device=DEVICE) for _ in range(10)]
  class_totals = [0] * 10
  for img, classes in dataloader:
    images = img.to(DEVICE)
    labels = classes.to(DEVICE)
    with torch.no_grad():
      _, encoded_images, _, _ = AE(images)
    # Group this batch's latent samples by label and accumulate per class.
    sorted_encoded_images = sort_data(encoded_images, labels)
    for i in range(len(sorted_encoded_images)):
      encoded_class_sums[i] = torch.add(torch.sum(sorted_encoded_images[i], 0),
                                        encoded_class_sums[i])
      class_totals[i] += len(sorted_encoded_images[i])

  AE.train(was_training)

  # Class encoding = mean latent sample of the class, shaped (1, ENCODED_DIM)
  # so it can be fed straight to AE.decoder.
  # NOTE: a class absent from the loader has a total of 0 and yields NaN here.
  class_encodings = [torch.div(encoded_class_sums[l], class_totals[l]).unsqueeze(0)
                     for l in range(len(encoded_class_sums))]

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Decode each class-mean latent vector into an image.
  # Eval mode makes the decoder's BatchNorm use its running statistics instead
  # of the statistics of a single vector, and stops it from updating them.
  # The previous mode is restored afterwards.
  was_training = AE.training
  AE.eval()
  with torch.no_grad():
    tensors = [AE.decoder(encoding) for encoding in class_encodings]
  AE.train(was_training)

  # Each decoded tensor is (1, 1, 28, 28); concatenating them gives a
  # (10, 1, 28, 28) batch, which display_class_examples plots as one row.
  display_class_examples(torch.cat(tensors))

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Pairwise cosine similarity of the 10 class-mean latent vectors:
  # L2-normalise every row, then take all row-by-row dot products.
  stacked_encodings = torch.stack(class_encodings, dim=0).squeeze()
  image_features = stacked_encodings / stacked_encodings.norm(dim=-1, keepdim=True)
  features_np = image_features.cpu().numpy()
  similarity = features_np @ features_np.T

  # Heat map with the numeric similarity written in every cell.
  plt.figure(figsize=(20, 14))
  plt.imshow(similarity)
  plt.yticks(range(10), fontsize=18)
  plt.xticks(range(10), fontsize=18)
  n_rows, n_cols = similarity.shape
  for col in range(n_cols):
    for row in range(n_rows):
      plt.text(col, row, f"{similarity[row, col]:.2f}", ha="center", va="center",
               size=12)

  # Frameless plot; the inverted y range (9.5 down to -2) leaves a band above the matrix.
  for spine_name in ("left", "top", "right", "bottom"):
    plt.gca().spines[spine_name].set_visible(False)

  plt.xlim([-0.5, 10 - 0.5])
  plt.ylim([9 + 0.5, -2])

  plt.title("Cosine similarity between averaged image features", size=20)

In [None]:
def show_image(img):
    """Draw a (C, H, W) image tensor on the current matplotlib axes.

    The tensor is detached and copied to the CPU first, so GPU tensors and
    tensors that require grad can be passed directly.
    """
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

AE.eval()

with torch.no_grad():
    # Encode one batch from `dataloader` (the *training* loader) and fit a
    # diagonal Gaussian to the resulting latent samples.
    images, labels = next(iter(dataloader))
    images = images.to(DEVICE)
    _, latent, _, _ = AE(images)
    latent = latent.cpu()

    # Per-dimension mean and population standard deviation of the codes.
    mean = latent.mean(dim=0)
    std = (latent - mean).pow(2).mean(dim=0).sqrt()

    # Decode 128 fresh draws from N(mean, std^2); show the first 100 in a 10-wide grid.
    latent = (torch.randn(128, ENCODED_DIM) * std + mean).to(DEVICE)
    img_recon = AE.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], nrow=10, padding=5))
    plt.show()

In [None]:
AE.eval()

with torch.no_grad():
    # One batch of *training* images; for every image, decode one draw from
    # its own posterior with the standard deviation inflated 10x.
    images, labels = next(iter(dataloader))
    images = images.to(DEVICE)
    _, latent, mu, logvar = AE(images)

    std = 10 * torch.exp(0.5*logvar)
    z = torch.randn_like(std)
    latent = (mu + z * std).to(DEVICE)
    img_recon = AE.decoder(latent).cpu()

    # First the decoded draws, then the source images, each as a 10-wide grid.
    for grid_source in (img_recon[:100], images[:100].cpu()):
        fig, ax = plt.subplots(figsize=(20, 8.5))
        show_image(torchvision.utils.make_grid(grid_source, nrow=10, padding=5))
        plt.show()

In [None]:
# Set eval mode explicitly: this cell used to rely on an earlier cell having
# called AE.eval(), so running it on its own decoded in train mode.
AE.eval()

with torch.no_grad():
    # Posterior parameters for one batch of *training* images.
    images, labels = next(iter(dataloader))
    images = images.to(DEVICE)
    _, latent, mu, logvar = AE(images)

    # Posterior standard deviation of every image, inflated 10x.
    std = 10 * torch.exp(0.5*logvar)
    z = torch.randn_like(std)

    # One draw per row of z, all around the FIRST image's posterior:
    # mu[0] and std[0] broadcast against z's batch dimension.
    latent = mu[0] + (z * std[0])
    img_recon = AE.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100],10,5))
    plt.show()

    # The source image the samples were drawn around.
    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(images[0].cpu(),10,5))
    plt.show()

In [None]:
AE.eval()

with torch.no_grad():
    # Summarise the posterior over one batch of *test* images by the batch
    # average of mu and of the posterior standard deviation.
    images, labels = next(iter(testloader))
    images = images.to(DEVICE)
    _, latent, mu, log_var = AE(images)

    batch_mu = mu.mean(dim=0)
    batch_sigma = (0.5 * log_var).exp().mean(dim=0)
    mean, std = batch_mu.cpu(), batch_sigma.cpu()

    # Decode 128 draws from N(mean, std^2); show the first 100 in a 10-wide grid.
    latent = (torch.randn(128, ENCODED_DIM) * std + mean).to(DEVICE)
    img_recon = AE.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], nrow=10, padding=5))
    plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Explicit eval mode: this cell previously relied on an earlier cell having
  # called AE.eval().
  AE.eval()
  # Interpolation steps between two class examples (both endpoints included).
  # Loop-invariant, so the blend weights are computed once.
  n = 20
  alphas = [k / (n - 1) for k in range(n)]
  with torch.no_grad():
    for i in range(10):
      for j in range(10):
        tensor1 = class_examples[i]
        tensor2 = class_examples[j]
        # Linear blend in pixel space; every blend is then passed through the VAE.
        interpolated_tensor = torch.stack(
            [(1 - alpha) * tensor1 + alpha * tensor2 for alpha in alphas])
        reconstructed, encoded, _, _ = AE(interpolated_tensor)
        # (1, 28, 28) -> (28, 28) for imshow
        images = [np.squeeze(tensor, axis=0).cpu() for tensor in reconstructed]

        # One subplot per interpolation step.
        fig, axes = plt.subplots(nrows=1, ncols=n, figsize=(20, 2.5))
        for image, ax in zip(images, axes.flatten()):
          ax.imshow(image, cmap='gray')
          ax.axis('off')
        plt.show()
        print("------------------------------------------------------------------------------------------------------------")

### Save Checkpoint

In [None]:
AE_path1 = "./AE_checkpoint1"
G_path1 = "./G_checkpoint1"
D_path1 = "./D_checkpoint1"

# Snapshot AE, G and D (weights + optimizer state) after the VAE stage so the
# cDCGAN section can restart from this point.
for ckpt_path, network, network_optimizer in ((AE_path1, AE, AE_optimizer),
                                              (G_path1, G, G_optimizer),
                                              (D_path1, D, D_optimizer)):
    torch.save({
        'model_state_dict': network.state_dict(),
        'optimizer_state_dict': network_optimizer.state_dict(),
    }, ckpt_path)

## Train the cDCGAN

### Load checkpoint

In [None]:
AE_checkpoint = torch.load(AE_path1)
# Restore the VAE weights first, then its optimizer state.
for state_key, component in (('model_state_dict', AE),
                             ('optimizer_state_dict', AE_optimizer)):
  component.load_state_dict(AE_checkpoint[state_key])

In [None]:
G_checkpoint = torch.load(G_path1)
# Restore the generator weights first, then its optimizer state.
for state_key, component in (('model_state_dict', G),
                             ('optimizer_state_dict', G_optimizer)):
  component.load_state_dict(G_checkpoint[state_key])

In [None]:
D_checkpoint = torch.load(D_path1)
# Restore the discriminator weights first, then its optimizer state.
for state_key, component in (('model_state_dict', D),
                             ('optimizer_state_dict', D_optimizer)):
  component.load_state_dict(D_checkpoint[state_key])

### Begin Training

In [None]:
def GAN_epoch(dataloader, AE, D, D_optimizer, G, G_optimizer,
              discrimination_loss, reconstruction_loss, train=True):
  """Run one conditional-GAN pass (training or evaluation) over `dataloader`.

  The autoencoder `AE` only provides conditioning here; its optimizer is never
  stepped. For each batch:
    * G is conditioned on the (detached) AE encodings of the real images, and
    * on NUM_SAMPLES latent vectors y drawn around each encoding from
      N(mu, (STD_SCALING * sigma)^2).
  D is trained to accept real images and reject both kinds of fakes; G is
  updated every `Ksteps` steps to fool D.

  Args:
    dataloader: iterable of `(images, classes)` batches (`classes` is unused).
    AE: autoencoder whose forward returns `(reconstructed, encoded, mu, logvar)`.
    D, D_optimizer: discriminator `D(images, condition_maps)` and its optimizer.
    G, G_optimizer: generator `G(noise, condition)` and its optimizer.
    discrimination_loss: criterion comparing D outputs with the real/fake targets.
    reconstruction_loss: criterion between two image batches (monitoring only).
    train: if False, losses are computed and logged but no backward pass or
      optimizer step is performed.

  Returns:
    Epoch means, in order: D loss, D x loss, D z loss, D y loss, mean D(x),
    G loss, G z loss, G z reconstruction loss, G y loss, mean D(G(z)),
    mean D(G(y)). The D z / D y entries are logged already halved, i.e. at
    their weight inside the D loss.

  Reads notebook globals: DEVICE, size_z, img_size, NUM_SAMPLES, STD_SCALING,
  Ksteps, labels_real, labels_fake.
  """
  def _average(values):
    # Mean of the per-batch floats collected over the epoch.
    return sum(values) / len(values)

  epoch_D_losses = []
  epoch_D_x_losses = []
  epoch_D_z_losses = []
  epoch_D_y_losses = []
  epoch_Dx = []

  epoch_G_losses = []
  epoch_G_z_losses = []
  epoch_G_z_reconstruction_losses = []
  epoch_G_y_losses = []
  epoch_DGz = []
  epoch_DGy = []

  step = 0
  for images, _classes in dataloader:
    images = images.to(DEVICE)
    # Size noise and targets by the actual batch so a short final batch does
    # not clash with the BATCH_SIZE-sized global label tensors.
    batch_size = images.size(0)
    real_targets = labels_real[:batch_size]
    fake_targets = labels_fake[:batch_size]

    ############################
    # Forward pass through the autoencoder (conditioning only)
    ############################
    _, encoded_images, mu, logvar = AE(images)

    ############################
    # Generator on the real encodings
    ############################
    z = torch.randn(batch_size, size_z).to(DEVICE)
    fake_images = G(z, encoded_images.detach())

    ############################
    # Sample NUM_SAMPLES latent vectors around each encoding
    ############################
    mean = mu.repeat(NUM_SAMPLES, 1)
    std = STD_SCALING * torch.exp(0.5*logvar).repeat(NUM_SAMPLES, 1)
    y = torch.randn_like(std)
    samples = (mean + (y * std)).to(DEVICE)

    # Generator on the sampled latents (same noise, repeated per sample).
    fake_samples = G(z.repeat(NUM_SAMPLES, 1), samples.detach())

    ############################
    # Discriminator losses
    ############################
    # Broadcast each encoding to a per-pixel map:
    # (batch, encoded_dim, img_size, img_size), assuming encoded_images is
    # (batch, encoded_dim) -- TODO confirm.
    D_x = encoded_images.unsqueeze(2).unsqueeze(3).detach().repeat(1, 1, img_size, img_size)
    D_y = D_x.repeat(NUM_SAMPLES, 1, 1, 1)

    # log(D(x)) on real images
    x_preds = D(images, D_x)
    D_x_loss = discrimination_loss(x_preds, real_targets)

    # log(1 - D(G(z))) on generated images
    z_preds = D(fake_images.detach(), D_x)
    D_z_loss = discrimination_loss(z_preds, fake_targets)

    # log(1 - D(G(z, y))) on images generated from sampled latents
    y_preds = D(fake_samples.detach(), D_y)
    D_y_loss = discrimination_loss(y_preds, fake_targets.repeat(NUM_SAMPLES, 1))

    # Real images count fully; the two kinds of fakes share equal weight.
    D_loss = D_x_loss + (D_z_loss + D_y_loss)/2

    epoch_D_losses.append(D_loss.item())
    epoch_D_x_losses.append(D_x_loss.item())
    # Logged at their weight inside D_loss so the components add up to it.
    epoch_D_z_losses.append(D_z_loss.item()/2)
    epoch_D_y_losses.append(D_y_loss.item()/2)
    epoch_Dx.append(x_preds.mean().item())

    if train:
      D.zero_grad()
      D_loss.backward()
      D_optimizer.step()

    ############################
    # Generator update every Ksteps discriminator steps
    ############################
    if step % Ksteps == 0:
      # Re-score the fakes with the freshly updated discriminator.
      z_out = D(fake_images, D_x)
      G_z_loss = discrimination_loss(z_out, real_targets)
      # Monitoring only: pixel-wise distance between G(z, enc(x)) and x.
      G_z_reconstruction_loss = reconstruction_loss(fake_images, images)

      y_out = D(fake_samples, D_y)
      G_y_loss = discrimination_loss(y_out, real_targets.repeat(NUM_SAMPLES, 1))

      # Only the adversarial terms drive G; reconstruction terms are not
      # part of the objective.
      G_loss = G_z_loss + G_y_loss

      epoch_G_losses.append(G_loss.item())
      epoch_G_z_losses.append(G_z_loss.item())
      epoch_G_z_reconstruction_losses.append(G_z_reconstruction_loss.item())
      epoch_G_y_losses.append(G_y_loss.item())
      epoch_DGz.append(z_out.mean().item())
      epoch_DGy.append(y_out.mean().item())

      if train:
        G.zero_grad()
        G_loss.backward()
        G_optimizer.step()
    step += 1

  ############################
  # Epoch averages
  ############################
  return _average(epoch_D_losses), _average(epoch_D_x_losses), \
         _average(epoch_D_z_losses), _average(epoch_D_y_losses), \
         _average(epoch_Dx), \
         _average(epoch_G_losses), _average(epoch_G_z_losses), \
         _average(epoch_G_z_reconstruction_losses), _average(epoch_G_y_losses), \
         _average(epoch_DGz), _average(epoch_DGy)

In [None]:
# Conditioning inputs for the fixed example images shown after every epoch.
# Within this section they are only consumed inside torch.no_grad() blocks,
# so no autograd graph needs to be kept for them.
with torch.no_grad():
  _, encoded_class_examples, _, _ = AE(class_examples)
class_examples_z = torch.randn(10, size_z).to(DEVICE)

with torch.no_grad():
  _, encoded_val_class_examples, _, _ = AE(val_class_examples)
val_class_examples_z = torch.randn(10, size_z).to(DEVICE)

if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Per-epoch histories used by the plotting cells below.
  D_losses = []
  D_x_losses = []
  D_z_losses = []
  D_y_losses = []
  Dx_values = []

  G_losses = []
  G_z_losses = []
  G_z_reconstruction_losses = []
  G_y_losses = []
  DGz_values = []
  DGy_values = []

  val_D_losses = []
  val_D_x_losses = []
  val_D_z_losses = []
  val_D_y_losses = []
  val_Dx_values = []

  val_G_losses = []
  val_G_z_losses = []
  val_G_z_reconstruction_losses = []
  val_G_y_losses = []
  val_DGz_values = []
  val_DGy_values = []

  # Same order as the tuple returned by GAN_epoch; zipping against these keeps
  # every metric paired with the right history list.
  train_histories = (D_losses, D_x_losses, D_z_losses, D_y_losses, Dx_values,
                     G_losses, G_z_losses, G_z_reconstruction_losses,
                     G_y_losses, DGz_values, DGy_values)
  val_histories = (val_D_losses, val_D_x_losses, val_D_z_losses,
                   val_D_y_losses, val_Dx_values,
                   val_G_losses, val_G_z_losses, val_G_z_reconstruction_losses,
                   val_G_y_losses, val_DGz_values, val_DGy_values)

  counter = 0
  # GAN_epoch never steps the AE optimizer, so the AE stays in eval mode here.
  AE.eval()
  # NOTE(review): EPOCHS is not set in the configuration cell (which defines
  # NUM_EPOCHS and NUM_PRETRAIN_EPOCHS); confirm it is defined elsewhere and
  # which value is intended for pre-training.
  for epoch in range(EPOCHS):
    ############################
    # Train
    ############################
    losses = GAN_epoch(dataloader, AE, D, D_optimizer, G, G_optimizer,
                       discrimination_loss, reconstruction_loss,
                       train=True)
    for history, value in zip(train_histories, losses):
      history.append(value)

    ############################
    # Display
    ############################
    G.eval()
    D.eval()
    print(f"Epoch {epoch+1}/{EPOCHS} Discriminator Loss {D_losses[-1]:.3f} "
        + f"Generator Loss {G_losses[-1]:.3f} "
        + f"D(x) {Dx_values[-1]:.3f} D(G(z)) {DGz_values[-1]:.3f} D(G(y)) {DGy_values[-1]:.3f}")
    with torch.no_grad():
      fake_class_examples = G(class_examples_z, encoded_class_examples)
      display_class_examples(class_examples)
      display_class_examples(fake_class_examples)
      _, latent, mu, logvar = AE(class_examples)
      std = STD_SCALING * torch.exp(0.5*logvar)
      z = torch.randn_like(std)
      # sample latent vectors from the widened normal distribution
      latent = mu + (z * std)
      samples = latent.to(DEVICE)
      fake_samples = G(class_examples_z, samples)
      reconstructed_samples = AE.decoder(samples)
      display_class_examples(fake_samples)
      display_class_examples(reconstructed_samples)

      if counter % EPOCHS_BETWEEN_VAL == 0:
        ############################
        # Validate
        ############################
        val_losses = GAN_epoch(testloader, AE, D, D_optimizer, G, G_optimizer,
                               discrimination_loss, reconstruction_loss,
                               train=False)
        # Record the metrics of this validation pass.
        for history, value in zip(val_histories, val_losses):
          history.append(value)

        ############################
        # Display
        ############################
        print(f"Validation Discriminator Loss {val_D_losses[-1]:.3f} "
            + f"Generator Loss {val_G_losses[-1]:.3f} "
            + f"Generator Reconstruction Loss {val_G_z_reconstruction_losses[-1]:.3f} "
            + f"D(x) {val_Dx_values[-1]:.3f} D(G(z)) {val_DGz_values[-1]:.3f} D(G(y)) {val_DGy_values[-1]:.3f}")
        fake_val_class_examples = G(val_class_examples_z, encoded_val_class_examples)
        display_class_examples(val_class_examples)
        display_class_examples(fake_val_class_examples)
        # Sample around the validation examples' own encodings.
        _, latent, mu, logvar = AE(val_class_examples)
        std = STD_SCALING * torch.exp(0.5*logvar)
        z = torch.randn_like(std)
        latent = mu + (z * std)
        samples = latent.to(DEVICE)
        fake_samples = G(val_class_examples_z, samples)
        reconstructed_samples = AE.decoder(samples)
        display_class_examples(fake_samples)
        display_class_examples(reconstructed_samples)

    D.train()
    G.train()
    counter += 1
  AE.train()

### Visualizing GAN results

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Discriminator Loss During Training")
  # Plot the total discriminator loss and all three of its components.
  # GAN_epoch logs the z and y terms already halved (their weight in D loss).
  plt.plot(D_losses,label="D Loss")
  plt.plot(D_x_losses,label="D x Loss")
  plt.plot(D_z_losses,label="D z Loss")
  plt.plot(D_y_losses,label="D y Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # add labels and create legend
  plt.xlabel("num_epochs")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Discriminator Loss During Validation")
  # Total validation D loss and all three of its components.
  plt.plot(val_D_losses,label="Validation D Loss")
  plt.plot(val_D_x_losses,label="Validation D x Loss")
  plt.plot(val_D_z_losses,label="Validation D z Loss")
  plt.plot(val_D_y_losses,label="Validation D y Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # One point per validation pass, not per epoch.
  plt.xlabel(f"validation run (every {EPOCHS_BETWEEN_VAL} epochs)")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Generator Loss During Training")
  # G Loss = G z Loss + G y Loss in GAN_epoch; the reconstruction curve is
  # monitored but not optimised.
  plt.plot(G_losses,label="G Loss")
  plt.plot(G_z_losses,label="G z Loss")
  plt.plot(G_y_losses,label="G y Loss")
  plt.plot(G_z_reconstruction_losses,label="G reconstruction Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # add labels and create legend
  plt.xlabel("num_epochs")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Generator Loss During Validation")
  # Validation G loss, both adversarial components, and the monitored
  # reconstruction loss.
  plt.plot(val_G_losses,label="Validation G Loss")
  plt.plot(val_G_z_losses,label="Validation G z Loss")
  plt.plot(val_G_y_losses,label="Validation G y Loss")
  plt.plot(val_G_z_reconstruction_losses,label="Validation G Reconstruction Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # One point per validation pass, not per epoch.
  plt.xlabel(f"validation run (every {EPOCHS_BETWEEN_VAL} epochs)")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('cDCGAN', 'VAE and cDCGAN'):
  # Overlay the total discriminator and generator losses, one point per epoch.
  fig, ax = plt.subplots(figsize=(10,5))
  ax.set_title("Discriminator and Generator Loss During Training")
  ax.plot(D_losses, label="D Loss")
  ax.plot(G_losses, label="G Loss")
  # drop the right/top frame lines
  for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
  ax.set_xlabel("num_epochs")
  ax.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Discriminator and Generator Loss During Validation")
  # plot Discriminator and generator validation loss
  plt.plot(val_D_losses,label="Validation D Loss")
  plt.plot(val_G_losses,label="Validation G Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # One point per validation pass, not per epoch.
  plt.xlabel(f"validation run (every {EPOCHS_BETWEEN_VAL} epochs)")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('cDCGAN', 'VAE and cDCGAN'):
  # Mean discriminator output on real images vs. on G(z), per training epoch.
  fig, ax = plt.subplots(figsize=(10,5))
  ax.set_title("Discriminator Accuracy During Training")
  ax.plot(Dx_values, label="D(x)")
  ax.plot(DGz_values, label="D(G(z))")
  # drop the right/top frame lines
  for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
  ax.set_xlabel("num_epochs")
  ax.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('cDCGAN', 'VAE and cDCGAN'):
  # Per-epoch gap between D's mean score on real images and on G(z).
  fig, ax = plt.subplots(figsize=(10,5))
  ax.set_title("Difference in Discriminator Accuracy During Training")
  diff = [real_score - fake_score
          for real_score, fake_score in zip(Dx_values, DGz_values)]
  ax.plot(diff, label="D(x) - D(G(z))")
  # drop the right/top frame lines
  for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
  ax.set_xlabel("num_epochs")
  ax.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Discriminator Accuracy During Validation")
  # Mean discriminator output on real images vs. on G(z) during validation.
  plt.plot(val_Dx_values,label="Validation D(x)")
  plt.plot(val_DGz_values,label="Validation D(G(z))")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # One point per validation pass, not per epoch.
  plt.xlabel(f"validation run (every {EPOCHS_BETWEEN_VAL} epochs)")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Difference in Discriminator Accuracy During Validation")
  # Gap between D's mean score on real images and on G(z), per validation run.
  diff = [real_score - fake_score
          for real_score, fake_score in zip(val_Dx_values, val_DGz_values)]
  plt.plot(diff,label="Validation D(x) - D(G(z))")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # One point per validation pass, not per epoch.
  plt.xlabel(f"validation run (every {EPOCHS_BETWEEN_VAL} epochs)")
  plt.legend()
  plt.show()

### Save Checkpoint

In [None]:
# Persist each network together with its optimizer state so that the
# "Train HAe/DaES" section can reload and resume from exactly this point.
AE_path2 = "./AE_checkpoint2"
G_path2 = "./G_checkpoint2"
D_path2 = "./D_checkpoint2"

for ckpt_path, ckpt_model, ckpt_optimizer in ((AE_path2, AE, AE_optimizer),
                                              (G_path2, G, G_optimizer),
                                              (D_path2, D, D_optimizer)):
  torch.save(dict(model_state_dict=ckpt_model.state_dict(),
                  optimizer_state_dict=ckpt_optimizer.state_dict()),
             ckpt_path)

## Train HAe/DaES

### Load checkpoint

In [None]:
# Restore the autoencoder and its optimizer from the second checkpoint.
# map_location makes the load device-independent (GPU-saved -> CPU runtime).
AE_checkpoint = torch.load(AE_path2, map_location=DEVICE)
AE.load_state_dict(AE_checkpoint['model_state_dict'])
AE_optimizer.load_state_dict(AE_checkpoint['optimizer_state_dict'])

In [None]:
# Restore the generator and its optimizer from the second checkpoint.
# map_location makes the load device-independent (GPU-saved -> CPU runtime).
G_checkpoint = torch.load(G_path2, map_location=DEVICE)
G.load_state_dict(G_checkpoint['model_state_dict'])
G_optimizer.load_state_dict(G_checkpoint['optimizer_state_dict'])

In [None]:
# Restore the discriminator and its optimizer from the second checkpoint.
# map_location makes the load device-independent (GPU-saved -> CPU runtime).
D_checkpoint = torch.load(D_path2, map_location=DEVICE)
D.load_state_dict(D_checkpoint['model_state_dict'])
D_optimizer.load_state_dict(D_checkpoint['optimizer_state_dict'])

### Begin Training

In [None]:
def HAeDaES_epoch(dataloader, AE, AE_optimizer, D, D_optimizer, G, G_optimizer,
                 discrimination_loss, reconstruction_loss, train=True,
                 reconstruction_weight=100, kld_weight=100):
  """Run one joint HAe/DaES pass (training or evaluation) over `dataloader`.

  D is trained to accept real images and to reject four kinds of fakes:
  autoencoder reconstructions AE(x), generator outputs G(z, enc(x)),
  generator outputs on sampled latents G(z, y), and decoded samples
  AE.decoder(y). The latents y are drawn from N(mu, (STD_SCALING * sigma)^2)
  with sigma detached. Every `Ksteps` steps G and the autoencoder are updated
  together from one combined loss. That loss holds the adversarial terms for
  both networks plus the autoencoder's weighted reconstruction and KL terms.

  Args:
    dataloader: iterable of `(images, classes)` batches (`classes` is unused).
    AE, AE_optimizer: autoencoder whose forward returns
      `(reconstructed, encoded, mu, logvar)`, and its optimizer.
    D, D_optimizer: discriminator `D(images, condition_maps)` and its optimizer.
    G, G_optimizer: generator `G(noise, condition)` and its optimizer.
    discrimination_loss: criterion comparing D outputs with the real/fake targets.
    reconstruction_loss: criterion between reconstructed and original images.
    train: if False, losses are computed and logged but no backward pass or
      optimizer step is performed.
    reconstruction_weight: multiplier on the AE reconstruction loss (default 100).
    kld_weight: multiplier on the AE KL-divergence term (default 100).

  Returns:
    18 epoch means, in order: D loss, D x loss, D AE(x) loss, D G(y) loss,
    D AE(y) loss, D G(z) loss, mean D(x), G loss, G y loss, G z loss,
    mean D(G(z)), mean D(G(y)), AE loss, AE x loss, AE reconstruction loss,
    AE y loss, mean D(AE(x)), mean D(AE(y)). The four fake-image D terms are
    logged already divided by 4, i.e. at their weight inside the D loss.

  Reads notebook globals: DEVICE, size_z, img_size, NUM_SAMPLES, STD_SCALING,
  Ksteps, labels_real, labels_fake, kl_divergence.
  """
  def _average(values):
    # Mean of the per-batch floats collected over the epoch.
    return sum(values) / len(values)

  epoch_D_losses = []
  epoch_D_x_losses = []
  epoch_D_reconstructed_x_losses = []
  epoch_D_y_losses = []
  epoch_D_reconstructed_y_losses = []
  epoch_D_z_losses = []
  epoch_Dx = []

  epoch_G_losses = []
  epoch_G_y_losses = []
  epoch_G_z_losses = []
  epoch_DGz = []
  epoch_DGy = []

  epoch_AE_losses = []
  epoch_AE_x_losses = []
  epoch_AE_x_reconstruction_losses = []
  epoch_AE_y_losses = []
  epoch_DAEx = []
  epoch_DAEy = []

  step = 0
  for images, _classes in dataloader:
    images = images.to(DEVICE)
    # Size noise and targets by the actual batch so a short final batch does
    # not clash with the BATCH_SIZE-sized global label tensors.
    batch_size = images.size(0)
    real_targets = labels_real[:batch_size]
    fake_targets = labels_fake[:batch_size]
    real_targets_y = real_targets.repeat(NUM_SAMPLES, 1)
    fake_targets_y = fake_targets.repeat(NUM_SAMPLES, 1)

    ############################
    # Forward pass through the autoencoder
    ############################
    reconstructed_images, encoded_images, mu, logvar = AE(images)

    ############################
    # Generator on the real encodings (not detached: G's loss also trains AE)
    ############################
    z = torch.randn(batch_size, size_z).to(DEVICE)
    fake_images = G(z, encoded_images)

    ############################
    # Sample NUM_SAMPLES latent vectors around each encoding
    ############################
    mean = mu.repeat(NUM_SAMPLES, 1)
    # sigma is detached: gradients reach mu through the samples, not logvar.
    std = STD_SCALING * torch.exp(0.5*logvar).detach().repeat(NUM_SAMPLES, 1)
    y = torch.randn_like(std)
    samples = (mean + (y * std)).to(DEVICE)
    reconstructed_samples = AE.decoder(samples)

    # Generator on the sampled latents (same noise, repeated per sample).
    fake_samples = G(z.repeat(NUM_SAMPLES, 1), samples)

    ############################
    # Discriminator losses
    ############################
    # Broadcast each encoding to a per-pixel map:
    # (batch, encoded_dim, img_size, img_size), assuming encoded_images is
    # (batch, encoded_dim) -- TODO confirm.
    D_x = encoded_images.unsqueeze(2).unsqueeze(3).detach().repeat(1, 1, img_size, img_size)
    D_y = D_x.repeat(NUM_SAMPLES, 1, 1, 1)

    # log(D(x)) on real images
    x_preds = D(images, D_x)
    D_x_loss = discrimination_loss(x_preds, real_targets)

    # log(1 - D(AE(x))) on reconstructions
    reconstructed_x_preds = D(reconstructed_images.detach(), D_x)
    D_reconstructed_x_loss = discrimination_loss(reconstructed_x_preds, fake_targets)

    # log(1 - D(G(z))) on generated images
    z_preds = D(fake_images.detach(), D_x)
    D_z_loss = discrimination_loss(z_preds, fake_targets)

    # log(1 - D(G(z, y))) on images generated from sampled latents
    y_preds = D(fake_samples.detach(), D_y)
    D_y_loss = discrimination_loss(y_preds, fake_targets_y)

    # log(1 - D(AE.decoder(y))) on decoded latent samples
    reconstructed_y_preds = D(reconstructed_samples.detach(), D_y)
    D_reconstructed_y_loss = discrimination_loss(reconstructed_y_preds, fake_targets_y)

    # Real images count fully; the four kinds of fakes share equal weight.
    D_loss = D_x_loss + (D_reconstructed_x_loss + \
                         D_y_loss + D_reconstructed_y_loss + \
                         D_z_loss)/4

    epoch_D_losses.append(D_loss.item())
    epoch_D_x_losses.append(D_x_loss.item())
    # Logged at their weight inside D_loss so the components add up to it.
    epoch_D_reconstructed_x_losses.append(D_reconstructed_x_loss.item()/4)
    epoch_D_y_losses.append(D_y_loss.item()/4)
    epoch_D_reconstructed_y_losses.append(D_reconstructed_y_loss.item()/4)
    epoch_D_z_losses.append(D_z_loss.item()/4)
    epoch_Dx.append(x_preds.mean().item())

    if train:
      D.zero_grad()
      D_loss.backward()
      D_optimizer.step()

    ############################
    # Joint G + AE update every Ksteps discriminator steps
    ############################
    if step % Ksteps == 0:
      # Re-score every fake with the freshly updated discriminator, now
      # targeting "real" so G and AE learn to fool it.
      z_preds = D(fake_images, D_x)
      G_z_loss = discrimination_loss(z_preds, real_targets)

      reconstructed_x_preds = D(reconstructed_images, D_x)
      AE_x_loss = discrimination_loss(reconstructed_x_preds, real_targets)

      y_preds = D(fake_samples, D_y)
      G_y_loss = discrimination_loss(y_preds, real_targets_y)

      reconstructed_y_preds = D(reconstructed_samples, D_y)
      AE_y_loss = discrimination_loss(reconstructed_y_preds, real_targets_y)

      # Weighted reconstruction and KL terms keep the autoencoder faithful.
      AE_x_reconstruction_loss = reconstruction_weight * reconstruction_loss(reconstructed_images, images)
      AE_x_KLD = kld_weight * kl_divergence(mu, logvar)

      G_loss = G_y_loss + G_z_loss

      # A single backward pass sends G's adversarial terms into the encoder too.
      AE_loss = G_loss + AE_x_loss + AE_y_loss + AE_x_reconstruction_loss + AE_x_KLD

      epoch_G_losses.append(G_loss.item())
      epoch_G_y_losses.append(G_y_loss.item())
      epoch_G_z_losses.append(G_z_loss.item())
      epoch_DGz.append(z_preds.mean().item())
      epoch_DGy.append(y_preds.mean().item())

      epoch_AE_losses.append(AE_loss.item())
      epoch_AE_x_losses.append(AE_x_loss.item())
      epoch_AE_x_reconstruction_losses.append(AE_x_reconstruction_loss.item())
      epoch_AE_y_losses.append(AE_y_loss.item())
      epoch_DAEx.append(reconstructed_x_preds.mean().item())
      epoch_DAEy.append(reconstructed_y_preds.mean().item())

      if train:
        AE.zero_grad()
        G.zero_grad()
        AE_loss.backward()
        G_optimizer.step()
        AE_optimizer.step()

    step += 1

  ############################
  # Epoch averages
  ############################
  return _average(epoch_D_losses), \
         _average(epoch_D_x_losses), \
         _average(epoch_D_reconstructed_x_losses), \
         _average(epoch_D_y_losses), \
         _average(epoch_D_reconstructed_y_losses), \
         _average(epoch_D_z_losses), \
         _average(epoch_Dx), \
         _average(epoch_G_losses), \
         _average(epoch_G_y_losses), \
         _average(epoch_G_z_losses), \
         _average(epoch_DGz), \
         _average(epoch_DGy), \
         _average(epoch_AE_losses), \
         _average(epoch_AE_x_losses), \
         _average(epoch_AE_x_reconstruction_losses), \
         _average(epoch_AE_y_losses), \
         _average(epoch_DAEx), \
         _average(epoch_DAEy)

In [None]:
if ARCHITECHTURE == "HAe/DaES":
  # Training histories (one entry per epoch) in the exact order of the tuple
  # returned by HAeDaES_epoch. The chained assignment binds every name to its
  # own fresh list and also keeps them grouped in `train_history`.
  train_history = (
      D_losses, D_x_losses, D_reconstructed_x_losses, D_y_losses,
      D_reconstructed_y_losses, D_z_losses, Dx_values,
      G_losses, G_y_losses, G_z_losses, DGz_values, DGy_values,
      AE_losses, AE_x_losses, AE_x_reconstruction_losses, AE_y_losses,
      DAEx_values, DAEy_values,
  ) = tuple([] for _ in range(18))

  # Validation histories, same order; one entry per validation pass.
  val_history = (
      val_D_losses, val_D_x_losses, val_D_reconstructed_x_losses,
      val_D_y_losses, val_D_reconstructed_y_losses, val_D_z_losses,
      val_Dx_values,
      val_G_losses, val_G_y_losses, val_G_z_losses, val_DGz_values,
      val_DGy_values,
      val_AE_losses, val_AE_x_losses, val_AE_x_reconstruction_losses,
      val_AE_y_losses, val_DAEx_values, val_DAEy_values,
  ) = tuple([] for _ in range(18))

  counter = 0

  for epoch in range(NUM_EPOCHS):
    # -- one training pass, then record every returned metric --
    losses = HAeDaES_epoch(dataloader, AE, AE_optimizer, D, D_optimizer,
                           G, G_optimizer, discrimination_loss,
                           reconstruction_loss, train=True)
    for history, value in zip(train_history, losses):
      history.append(value)

    # -- report in inference mode --
    AE.eval()
    G.eval()
    D.eval()
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} D Loss {D_losses[-1]:.3f} "
          f"G Loss {G_losses[-1]:.3f} AE Loss {AE_losses[-1]:.3f} "
          f"D(x) {Dx_values[-1]:.3f} D(G(z)) {DGz_values[-1]:.3f} "
          f"D(AE(x)) {DAEx_values[-1]:.3f} D(G(y)) {DGy_values[-1]:.3f} "
          f"D(AE(y)) {DAEy_values[-1]:.3f}")
    with torch.no_grad():
      # Fixed training examples: originals, G(z, enc(x)) and AE(x).
      reconstructed_class_examples, encoded_class_examples, _, _ = AE(class_examples)
      fake_class_examples = G(class_examples_z, encoded_class_examples)
      display_class_examples(class_examples)
      display_class_examples(fake_class_examples)
      display_class_examples(reconstructed_class_examples)
      # One widened-posterior sample per example, shown through G and the decoder.
      _, latent, mu, logvar = AE(class_examples)
      std = STD_SCALING * torch.exp(0.5*logvar)
      z = torch.randn_like(std)
      latent = mu + (z * std)
      samples = latent.to(DEVICE)
      fake_samples = G(class_examples_z, samples)
      reconstructed_samples = AE.decoder(samples)
      display_class_examples(fake_samples)
      display_class_examples(reconstructed_samples)

      # -- periodic validation pass (no parameter updates) --
      if counter % EPOCHS_BETWEEN_VAL == 0:
        val_losses = HAeDaES_epoch(testloader, AE, AE_optimizer, D, D_optimizer,
                                   G, G_optimizer, discrimination_loss,
                                   reconstruction_loss, train=False)
        for history, value in zip(val_history, val_losses):
          history.append(value)

        print(f"Validation D Loss {val_D_losses[-1]:.3f} "
              f"G Loss {val_G_losses[-1]:.3f} AE Loss {val_AE_losses[-1]:.3f} "
              f"D(x) {val_Dx_values[-1]:.3f} D(G(z)) {val_DGz_values[-1]:.3f} "
              f"D(AE(x)) {val_DAEx_values[-1]:.3f} D(G(y)) {val_DGy_values[-1]:.3f} "
              f"D(AE(y)) {val_DAEy_values[-1]:.3f}")
        # Fixed validation examples: originals, G(z, enc(x)) and AE(x).
        reconstructed_val_class_examples, encoded_val_class_examples, _, _ = AE(val_class_examples)
        fake_val_class_examples = G(val_class_examples_z, encoded_val_class_examples)
        reconstructed_fake_val_class_examples, _, _, _ = AE(fake_val_class_examples)
        display_class_examples(val_class_examples)
        display_class_examples(fake_val_class_examples)
        display_class_examples(reconstructed_val_class_examples)
        # One widened-posterior sample per validation example.
        _, latent, mu, logvar = AE(val_class_examples)
        std = STD_SCALING * torch.exp(0.5*logvar)
        z = torch.randn_like(std)
        latent = mu + (z * std)
        samples = latent.to(DEVICE)
        fake_samples = G(val_class_examples_z, samples)
        reconstructed_samples = AE.decoder(samples)
        display_class_examples(fake_samples)
        display_class_examples(reconstructed_samples)
    # -- back to training mode for the next epoch --
    D.train()
    G.train()
    AE.train()
    counter += 1

### Visualizing HAe/DaES Results

In [None]:
def plot_losses(training_losses, validation_losses, offset_factor):
    """Overlay validation and training loss curves on the current figure.

    Args:
        training_losses: sequence of losses plotted at x = 0, 1, 2, ...
        validation_losses: sequence of losses recorded every `offset_factor`
            steps; plotted at x = 0, offset_factor, 2 * offset_factor, ...
        offset_factor: spacing (in training steps) between validation points.
    """
    val_steps = [k * offset_factor for k in range(len(validation_losses))]
    train_steps = range(len(training_losses))
    # Validation is drawn first so it keeps the first colour / legend slot.
    plt.plot(val_steps, validation_losses, label='Validation Loss')
    plt.plot(train_steps, training_losses, label='Training Loss')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
plot_losses(AE_losses, val_AE_losses, EPOCHS_BETWEEN_VAL)

In [None]:
# Divide each per-term discriminator loss history by D_TERM_LOSS_SCALE before
# it is plotted next to D_losses in the following cell.
# NOTE(review): the reason for a factor of 4 is not visible here - confirm.
# WARNING: the lists are modified in place, so re-running this cell divides
# them again (1/16, 1/64, ...). Re-run the training cell first if needed.
D_TERM_LOSS_SCALE = 4
for term_losses in (D_reconstructed_x_losses, D_reconstructed_y_losses,
                    D_y_losses, D_z_losses):
  # Slice assignment keeps the same list object and scales every entry of
  # each list, even if the four histories differ in length (the previous
  # index loop used the length of the first list for all four).
  term_losses[:] = [loss / D_TERM_LOSS_SCALE for loss in term_losses]

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('cDCGAN', 'VAE and cDCGAN'):
  # The overall D loss plus the individual D loss terms recorded in training.
  d_loss_curves = [
      (D_losses, "Discriminator Loss"),
      (D_x_losses, "Discriminator x Loss"),
      (D_reconstructed_x_losses, "Discriminator AE(x) Loss"),
      (D_y_losses, "Discriminator G(y) Loss"),
      (D_reconstructed_y_losses, "Discriminator AE(y) Loss"),
      (D_z_losses, "Discriminator G(z) Loss"),
  ]
  plt.figure(figsize=(10,5))
  plt.title("Discriminator Loss During Training")
  for curve, curve_label in d_loss_curves:
    plt.plot(curve, label=curve_label)
  ax = plt.gca()
  # hide the right and top frame lines
  for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
  plt.xlabel("num_epochs")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  plt.title("Discriminator Output and Loss on G(y) During Training")
  # Distinct labels so the legend tells the two curves apart
  # (both were previously labelled "Discriminator Loss").
  plt.plot(DGy_values, label="D(G(y))")
  plt.plot(D_y_losses, label="Discriminator G(y) Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # add labels and create legend
  plt.xlabel("num_epochs")
  plt.legend()
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'cDCGAN' or ARCHITECHTURE == 'VAE and cDCGAN':
  plt.figure(figsize=(10,5))
  # The curves below are generator losses, so title the plot accordingly.
  plt.title("Generator Loss During Training")
  # plot the two generator loss terms
  plt.plot(G_y_losses,label="G y Loss")
  plt.plot(G_z_losses,label="G z Loss")
  # get plot axis
  ax = plt.gca()
  # remove right and top spine
  ax.spines['right'].set_visible(False)
  ax.spines['top'].set_visible(False)
  # add labels and create legend
  plt.xlabel("num_epochs")
  plt.legend()
  plt.show()

In [None]:
# Diagnostic: compare D's predictions on one batch of real images with its
# predictions on their AE reconstructions, and the discrimination loss of each
# against real/fake targets. Gradients are not needed, so run under no_grad.
with torch.no_grad():
    images, _ = next(iter(dataloader))
    images = images.to(DEVICE)
    reconstructed_images, encoded_images, _, _ = AE(images)
    # Tile the encoding over the spatial dims to form D's conditioning input
    D_x = encoded_images.unsqueeze(2).unsqueeze(3).repeat(1, 1, img_size, img_size)
    x_preds = D(images, D_x)
    reconstructed_x_preds = D(reconstructed_images, D_x)
    # Match the target tensors to the actual batch size (a short batch would
    # otherwise mismatch the BATCH_SIZE-row label tensors).
    batch_real = labels_real[:len(images)]
    batch_fake = labels_fake[:len(images)]
print(torch.max(x_preds))
print(torch.max(reconstructed_x_preds))
print(torch.min(x_preds))
print(torch.min(reconstructed_x_preds))
print(discrimination_loss(x_preds, batch_real))
print(discrimination_loss(x_preds, batch_fake))
print(discrimination_loss(reconstructed_x_preds, batch_real))
print(discrimination_loss(reconstructed_x_preds, batch_fake))

In [None]:
def sort_data(data, labels, num_classes=10):
  """
  Groups the rows of `data` by their class label.
  Args:
    data: a tensor whose first dimension indexes samples, e.g. (N, D).
    labels: a 1-D integer tensor of length N with values in [0, num_classes).
    num_classes: number of classes / output groups (default 10).
  Returns:
    A list of `num_classes` tensors; the i-th holds, in their original order,
    every row of `data` whose label is i (a (0, ...) tensor if none). Results
    stay on the device of `data` and keep its dtype.
  Raises:
    ValueError: if any label lies outside [0, num_classes).
  """
  # Boolean masks replace the per-row torch.cat loop, which re-allocated a
  # class tensor for every sample (O(N^2) copying) and hard-coded .cuda(),
  # so it failed on CPU-only machines.
  labels = labels.to(data.device)
  if labels.numel() > 0 and (labels.min() < 0 or labels.max() >= num_classes):
    raise ValueError(f"labels must lie in [0, {num_classes}), got range "
                     f"[{int(labels.min())}, {int(labels.max())}]")
  return [data[labels == c] for c in range(num_classes)]

if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Mean AE encoding of every class over all of `dataloader`.
  # DEVICE is used instead of hard-coded .cuda() so this also runs on CPU.
  encoded_class_sums = [torch.zeros(ENCODED_DIM, device=DEVICE) for _ in range(10)]
  class_totals = [0] * 10
  # NOTE(review): AE is not put in eval() mode here - confirm train-mode
  # behaviour of its forward pass is acceptable for computing class means.
  # Keep the loop variable named `img`: other cells read it after this loop.
  for img, classes in dataloader:
    images = img.to(DEVICE)
    labels = classes.to(DEVICE)
    with torch.no_grad():
      _, encoded_images, _, _ = AE(images)
    # Group this batch's encodings by label, then accumulate per class.
    sorted_encoded_images = sort_data(encoded_images, labels)
    for i in range(len(sorted_encoded_images)):
      encoded_class_sums[i] = torch.add(torch.sum(sorted_encoded_images[i], 0),
                                        encoded_class_sums[i])
      class_totals[i] += len(sorted_encoded_images[i])

  # Class encoding = mean encoding of that class, with a leading batch dim of 1.
  # NOTE(review): a class with no samples yields 0/0 = NaN.
  class_encodings = [torch.div(encoded_class_sums[l], class_totals[l]).unsqueeze(0)
                     for l in range(len(encoded_class_sums))]

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('VAE', 'VAE and cDCGAN'):
  # Decode each class's mean encoding and show the ten results in one row.
  with torch.no_grad():
    tensors = [AE.decoder(encoding).detach() for encoding in class_encodings]
  # Remove singleton dims so each decoded image is 2-D for imshow.
  images = [np.squeeze(tensor).cpu() for tensor in tensors]

  fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(20, 2.5))
  axes = axes.flatten()
  for ax, image in zip(axes, images):
      ax.imshow(image, cmap='gray')
      ax.axis('off')
  plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('VAE', 'VAE and cDCGAN'):
  # Pairwise cosine similarity between the ten class-mean encodings.
  image_features = torch.stack(class_encodings, dim=0).squeeze()
  image_features /= image_features.norm(dim=-1, keepdim=True)
  feats_np = image_features.cpu().numpy()
  similarity = feats_np @ feats_np.T

  plt.figure(figsize=(20, 14))
  plt.imshow(similarity)
  plt.xticks(range(10), fontsize=18)
  plt.yticks(range(10), fontsize=18)
  # Write each similarity value in the centre of its cell.
  n_rows, n_cols = similarity.shape
  for y in range(n_rows):
    for x in range(n_cols):
      plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center",
               size=12)

  for side in ("left", "top", "right", "bottom"):
    plt.gca().spines[side].set_visible(False)

  plt.xlim([-0.5, 9.5])
  plt.ylim([9.5, -2])

  plt.title("Cosine similarity between averaged image features", size=20)

In [None]:
# Perturb each training image's latent code within its (widened) posterior and
# decode it with both AE.decoder and the generator G, next to the originals.
AE.eval()

with torch.no_grad():
    # Posterior parameters (mu, logvar) for one batch from `dataloader`
    images, labels = next(iter(dataloader))
    images = images.to(DEVICE)
    _, _, mu, logvar = AE(images)

    # Use the shared STD_SCALING config (this cell previously hard-coded 10,
    # unlike every other sampling cell).
    std = STD_SCALING * torch.exp(0.5 * logvar)
    # sample latent vectors: mu + eps * std with eps ~ N(0, I)
    latent = mu + torch.randn_like(std) * std

    # reconstruct images from the random latent vectors
    img_recon = AE.decoder(latent).cpu()

    z = torch.randn(len(latent), size_z).to(DEVICE)
    img_fake = G(z, latent).cpu()

    # G samples, AE decodings, then the original images (first 100 of each)
    for grid_source in (img_fake, img_recon, images.cpu()):
        fig, ax = plt.subplots(figsize=(20, 8.5))
        show_image(torchvision.utils.make_grid(grid_source[:100], 10, 5))
        plt.show()

In [None]:
with torch.no_grad():
    # Fit a per-dimension Gaussian to the AE encodings of the first batch of
    # `dataloader`, then decode fresh samples drawn from it.
    images, labels = next(iter(dataloader))
    images = images.to(DEVICE)
    _, latent, _, _ = AE(images)
    latent = latent.cpu()

    # per-dimension mean and population standard deviation
    mean = latent.mean(dim=0)
    centred = latent - mean
    std = centred.pow(2).mean(dim=0).sqrt()

    # 128 latents from N(mean, std^2), decoded by the AE
    latent = (torch.randn(128, ENCODED_DIM) * std + mean).to(DEVICE)
    img_recon = AE.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()

In [None]:
with torch.no_grad():
    # Average the posterior parameters over one `testloader` batch and decode
    # samples drawn around that average with AE.decoder.
    images, labels = next(iter(testloader))
    images = images.to(DEVICE)
    _, latent, mu, log_var = AE(images)

    mean = mu.mean(dim=0).cpu()
    avg_sigma = torch.exp(0.5 * log_var).mean(dim=0).cpu()
    std = STD_SCALING * avg_sigma

    # 128 latents from N(mean, std^2)
    sampled = torch.randn(128, ENCODED_DIM) * std + mean
    latent = sampled.to(DEVICE)
    img_recon = AE.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()

In [None]:
with torch.no_grad():
    # Average the posterior parameters over one `testloader` batch, sample
    # latents around that average and use them to condition the generator G.
    images, labels = next(iter(testloader))
    images = images.to(DEVICE)
    _, latent, mu, log_var = AE(images)

    mean = mu.mean(dim=0).cpu()
    avg_sigma = torch.exp(0.5 * log_var).mean(dim=0).cpu()
    std = STD_SCALING * avg_sigma

    # 128 latents from N(mean, std^2), plus one noise vector each for G
    sampled = torch.randn(128, ENCODED_DIM) * std + mean
    z = torch.randn(len(sampled), size_z).to(DEVICE)
    latent = sampled.to(DEVICE)
    img_recon = G(z, latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()

### Save Checkpoint

In [None]:
# Persist model and optimizer state for AE, G and D (in that order).
AE_path3 = "./AE_checkpoint3"
G_path3 = "./G_checkpoint3"
D_path3 = "./D_checkpoint3"

for _net, _opt, _path in ((AE, AE_optimizer, AE_path3),
                          (G, G_optimizer, G_path3),
                          (D, D_optimizer, D_path3)):
    torch.save({
                'model_state_dict': _net.state_dict(),
                'optimizer_state_dict': _opt.state_dict(),
                }, _path)

### Load Checkpoint

In [None]:
# map_location lets a checkpoint written on a GPU be restored on a CPU-only
# machine (and vice versa); load_state_dict then copies into the live model.
AE_checkpoint = torch.load(AE_path3, map_location=DEVICE)
AE.load_state_dict(AE_checkpoint['model_state_dict'])
AE_optimizer.load_state_dict(AE_checkpoint['optimizer_state_dict'])

In [None]:
# map_location lets a checkpoint written on a GPU be restored on a CPU-only
# machine (and vice versa); load_state_dict then copies into the live model.
G_checkpoint = torch.load(G_path3, map_location=DEVICE)
G.load_state_dict(G_checkpoint['model_state_dict'])
G_optimizer.load_state_dict(G_checkpoint['optimizer_state_dict'])

In [None]:
# map_location lets a checkpoint written on a GPU be restored on a CPU-only
# machine (and vice versa); load_state_dict then copies into the live model.
D_checkpoint = torch.load(D_path3, map_location=DEVICE)
D.load_state_dict(D_checkpoint['model_state_dict'])
D_optimizer.load_state_dict(D_checkpoint['optimizer_state_dict'])

# Visualizing the Autoencoder

In [None]:
# Reconstruct one batch of images with the AE for the grids below.
# `img` is loaded explicitly instead of relying on the loop variable leaked by
# the (guarded) class-encoding cell, which may not have run.
img = next(iter(dataloader))[0]
with torch.no_grad():
    results, _, _, _ = AE(img.to(DEVICE))

In [None]:
# Show up to the first 32 input images (`img`) in an 8 x 4 grid.
# matplotlib and numpy are already imported in the setup cell.
tensor = img.cpu()
print(tensor.shape)

# Drop the channel dimension, (N, 1, H, W) -> (N, H, W)
# (assumes single-channel images, as in MNIST / Fashion MNIST).
tensor = tensor.reshape(-1, tensor.shape[-2], tensor.shape[-1])

fig, axes = plt.subplots(8, 4, figsize=(20, 20))
axes = axes.ravel()

# zip stops at the shorter sequence, so a batch with fewer than 32 images
# no longer raises IndexError; all 32 panels still have their axes hidden.
for ax, image in zip(axes, tensor):
    ax.imshow(image, cmap='gray')
for ax in axes:
    ax.axis('off')

plt.show()

In [None]:
# Show up to the first 32 AE reconstructions (`results`) in an 8 x 4 grid.
tensor = results.detach().cpu()

# Drop the channel dimension, (N, 1, H, W) -> (N, H, W)
# (assumes single-channel images, as in MNIST / Fashion MNIST).
tensor = tensor.reshape(-1, tensor.shape[-2], tensor.shape[-1])

fig, axes = plt.subplots(8, 4, figsize=(20, 20))
axes = axes.ravel()

# zip stops at the shorter sequence, so a batch with fewer than 32 images
# no longer raises IndexError; all 32 panels still have their axes hidden.
for ax, image in zip(axes, tensor):
    ax.imshow(image, cmap='gray')
for ax in axes:
    ax.axis('off')

plt.show()

In [None]:
if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Mean AE encoding of every class over all of `dataloader`.
  # DEVICE is used instead of hard-coded .cuda() so this also runs on CPU.
  encoded_class_sums = [torch.zeros(ENCODED_DIM, device=DEVICE) for _ in range(10)]
  class_totals = [0] * 10
  # NOTE(review): AE is not put in eval() mode here - confirm train-mode
  # behaviour of its forward pass is acceptable for computing class means.
  # Keep the loop variable named `img`: other cells read it after this loop.
  for img, classes in dataloader:
    images = img.to(DEVICE)
    labels = classes.to(DEVICE)
    with torch.no_grad():
      _, encoded_images, _, _ = AE(images)
    # Group this batch's encodings by label, then accumulate per class.
    sorted_encoded_images = sort_data(encoded_images, labels)
    for i in range(len(sorted_encoded_images)):
      encoded_class_sums[i] = torch.add(torch.sum(sorted_encoded_images[i], 0),
                                        encoded_class_sums[i])
      class_totals[i] += len(sorted_encoded_images[i])

  # Class encoding = mean encoding of that class, with a leading batch dim of 1.
  # NOTE(review): a class with no samples yields 0/0 = NaN.
  class_encodings = [torch.div(encoded_class_sums[l], class_totals[l]).unsqueeze(0)
                     for l in range(len(encoded_class_sums))]

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('VAE', 'VAE and cDCGAN'):
  # Decode each class's mean encoding and show the ten results in one row.
  with torch.no_grad():
    tensors = [AE.decoder(encoding).detach() for encoding in class_encodings]
  # Remove singleton dims so each decoded image is 2-D for imshow.
  images = [np.squeeze(tensor).cpu() for tensor in tensors]

  fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(20, 2.5))
  axes = axes.flatten()
  for ax, image in zip(axes, images):
      ax.imshow(image, cmap='gray')
      ax.axis('off')
  plt.show()

In [None]:
# Same guard as the cell that builds `class_encodings`; without it this cell
# raises NameError whenever that cell was skipped.
if PRE_TRAIN or ARCHITECHTURE == 'VAE' or ARCHITECHTURE == 'VAE and cDCGAN':
  # Pairwise cosine similarity between the ten class-mean encodings.
  image_features = torch.stack(class_encodings, dim=0).squeeze()
  image_features /= image_features.norm(dim=-1, keepdim=True)
  similarity = image_features.cpu().numpy() @ image_features.cpu().numpy().T
  plt.figure(figsize=(20, 14))
  plt.imshow(similarity)
  plt.yticks(range(10), fontsize=18)
  plt.xticks(range(10), fontsize=18)
  for x in range(similarity.shape[1]):
      for y in range(similarity.shape[0]):
          plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

  for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

  plt.xlim([-0.5, 10 - 0.5])
  plt.ylim([9 + 0.5, -2])

  plt.title("Cosine similarity between averaged image features", size=20)

In [None]:
with torch.no_grad():
    # Fit a per-dimension Gaussian to the AE encodings of the first batch of
    # `dataloader`, then decode fresh samples drawn from it.
    images, labels = next(iter(dataloader))
    images = images.to(DEVICE)
    _, latent, _, _ = AE(images)
    latent = latent.cpu()

    # per-dimension mean and population standard deviation
    mean = latent.mean(dim=0)
    centred = latent - mean
    std = centred.pow(2).mean(dim=0).sqrt()

    # 128 latents from N(mean, std^2), decoded by the AE
    latent = (torch.randn(128, ENCODED_DIM) * std + mean).to(DEVICE)
    img_recon = AE.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()

# Visualizing the GAN



### Plot for Discriminator and Generator loss over the epochs

In [None]:
# Discriminator vs. generator training loss on one set of axes.
plt.figure(figsize=(10,5))
plt.title("Discriminator and Generator loss during Training")
for curve, curve_label in ((D_losses, "D Loss"), (G_losses, "G Loss")):
    plt.plot(curve, label=curve_label)
ax = plt.gca()
# hide the right and top frame lines
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)
plt.xlabel("num_epochs")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(10,5))
# These series are the discriminator's recorded output values on each kind of
# input, not a classification accuracy.
plt.title("Discriminator Outputs During Training")
# label each curve with the input D was evaluated on
plt.plot(Dx_values,label="D(x)")
plt.plot(DGz_values,label="D(G(z))")
plt.plot(DAEy_values,label="D(AE(y))")
plt.plot(DAEx_values,label="D(AE(x))")
plt.plot(DGy_values,label="D(G(y))")
# get plot axis
ax = plt.gca()
# remove right and top spine
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# add labels and create legend
plt.xlabel("num_epochs")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(10,5))
plt.title("Discriminator Loss and D(x) During Validation")
# Validation values are recorded once every EPOCHS_BETWEEN_VAL epochs, so
# place each point at the epoch it was measured. (The two labels were
# previously swapped.)
for series, series_label in ((val_D_losses, "Val D Loss"),
                             (val_Dx_values, "Val D(x)")):
    val_epochs = [k * EPOCHS_BETWEEN_VAL for k in range(len(series))]
    plt.plot(val_epochs, series, label=series_label)
# get plot axis
ax = plt.gca()
# remove right and top spine
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# add labels and create legend
plt.xlabel("num_epochs")
plt.legend()
plt.show()

In [None]:
# Linearly interpolate, in pixel space, from class_examples[1] to
# class_examples[2] in n equal steps (alpha = 0, 1/(n-1), ..., 1).
# (The unused batch previously fetched from `dataloader` is no longer loaded.)
tensor1 = class_examples[1]
tensor2 = class_examples[2]
n = 20
interpolated_tensors = []
for i in range(n):
    alpha = i / (n - 1)
    interpolated_tensors.append((1 - alpha) * tensor1 + alpha * tensor2)

# Stack the n interpolation steps into one batch tensor
interpolated_tensor = torch.stack(interpolated_tensors)

In [None]:
def _show_image_row(batch):
  """Plot a batch of images as one row of panels and show it.

  Assumes each image is (1, H, W); the leading channel dim is dropped.
  """
  row = [image.squeeze(0).cpu() for image in batch]
  # squeeze=False keeps `axes` an array even when there is a single panel.
  fig, axes = plt.subplots(nrows=1, ncols=len(row), figsize=(20, 2.5), squeeze=False)
  for image, ax in zip(row, axes.flatten()):
    ax.imshow(image, cmap='gray')
    ax.axis('off')
  plt.show()


with torch.no_grad():
  n = 20  # interpolation steps per pair
  # For every ordered pair (i, j) of class examples, interpolate in pixel space
  # from example i to example j, then show the AE reconstruction of each step
  # and a G sample conditioned on that step's AE encoding.
  for i in range(10):
    for j in range(10):
      tensor1 = class_examples[i]
      tensor2 = class_examples[j]
      interpolated_tensors = []
      for k in range(n):
        alpha = k / (n - 1)
        interpolated_tensors.append((1 - alpha) * tensor1 + alpha * tensor2)
      interpolated_tensor = torch.stack(interpolated_tensors)

      # One AE pass feeds both rows. The AE used to run twice per pair, which
      # doubled the cost and, if its forward pass is stochastic, conditioned G
      # on a different encoding than the one reconstructed above.
      reconstructed, encoded, _, _ = AE(interpolated_tensor)
      _show_image_row(reconstructed)

      z = torch.randn(len(encoded), size_z).to(DEVICE)
      fake = G(z, encoded)
      _show_image_row(fake)
      print("------------------------------------------------------------------------------------------------------------")

In [None]:
with torch.no_grad():
  # Condition G on the AE encodings of `interpolated_tensor` (fresh noise per
  # step) and show the generated sequence as one row.
  reconstructed, encoded, _, _ = AE(interpolated_tensor)
  z = torch.randn(len(encoded), size_z).to(DEVICE)
  fake = G(z, encoded)
  # drop each image's leading (channel) dim for imshow
  images = [tensor.squeeze(0).cpu() for tensor in fake]

  fig, axes = plt.subplots(nrows=1, ncols=20, figsize=(20, 2.5))
  axes = axes.ravel()
  for ax, image in zip(axes, images):
    ax.imshow(image, cmap='gray')
    ax.axis('off')

  plt.show()

In [None]:
import os
from PIL import Image
import numpy as np

def save_images(images, folder_path, batch_idx):
    """Save every image of a batch as an 8-bit grayscale PNG.

    Args:
        images: batch tensor, presumably (B, C, H, W) with pixel values in
            [0, 1]; only channel 0 is written. Values outside [0, 1] are
            clipped. May live on any device.
        folder_path: destination directory, created if missing.
        batch_idx: index of this batch in its loader; used to keep file names
            unique across calls.

    Files are named "{batch_idx:05d}_{i:04d}.png". The previous scheme,
    batch_idx * images.shape[0] + i, collided whenever a batch was smaller
    than the earlier ones (batch sizes 128, 128, 16 -> the 16-image batch
    wrote to 32..47 and overwrote files from batch 0).
    NOTE: clear folders written with the old naming before re-running, or
    stale files will be mixed into the FID statistics.
    """
    os.makedirs(folder_path, exist_ok=True)

    # Scale to [0, 255], round to nearest and clip before the uint8 cast (as
    # torchvision.utils.save_image does); a bare np.uint8() cast truncates and
    # turns out-of-range values such as -0.1 or 1.02 into garbage pixels.
    pixels = (images.detach().cpu()[:, 0].float()
              .mul(255).add(0.5).clamp(0, 255)
              .to(torch.uint8).numpy())

    for i, pixel_array in enumerate(pixels):
        file_name = f"{batch_idx:05d}_{i:04d}.png"
        Image.fromarray(pixel_array).save(os.path.join(folder_path, file_name))

# Write every test image to ./original (the FID reference folder).
with torch.no_grad():
  for i, (images, classes) in enumerate(testloader):
    # save_images moves the batch to the CPU itself, so the batch is not
    # copied to DEVICE first (that was a wasted host -> device -> host trip).
    save_images(images, "./original", i)

In [None]:
# For every test batch, save the AE reconstructions to ./reconstructed and
# G samples conditioned on the batch's AE encodings to ./generated.
with torch.no_grad():
  for i, (images, classes) in enumerate(testloader):
    images = images.to(DEVICE)
    reconstruction, encoded, _, _ = AE(images)
    save_images(reconstruction, "./reconstructed", i)

    # Condition G on *this* batch's encodings. The old code passed the stale
    # global `encoded_images` left over from another cell, with BATCH_SIZE
    # noise rows that do not match a short final batch.
    z = torch.randn(len(images), size_z).to(DEVICE)
    fake_images = G(z, encoded)
    save_images(fake_images, "./generated", i)

In [None]:
! pip install pytorch-fid

In [None]:
! python -m pytorch_fid ./original ./reconstructed

In [None]:
! python -m pytorch_fid ./original ./generated

In [None]:
# Stand-alone baseline autoencoder, trained on its own in a later cell.
AE_test = Autoencoder(ENCODED_DIM).to(DEVICE)
print(AE_test)
trainable_params = [param for param in AE_test.parameters() if param.requires_grad]
AE_total_params = sum(param.numel() for param in trainable_params)
print(f"Total Trainable Parameters: {AE_total_params}")

In [None]:
# Adam over all baseline-AE parameters, learning rate 1e-3.
AE_test_optimizer = torch.optim.Adam(params=AE_test.parameters(), lr=1e-3)

In [None]:
# Train the baseline AE for NUM_EPOCHS + NUM_PRETRAIN_EPOCHS epochs, showing
# example reconstructions every epoch and validating every EPOCHS_BETWEEN_VAL.
AE_test_losses = []
test_reconstruction_losses = []

val_AE_test_losses = []
val_test_reconstruction_losses = []

total_epochs = NUM_EPOCHS + NUM_PRETRAIN_EPOCHS
counter = 0
for epoch in range(total_epochs):
  # ---------------- Train ----------------
  losses = AE_epoch(dataloader, AE_test, AE_test_optimizer,
                    reconstruction_loss, train=True)
  train_loss, train_recon = losses[0], losses[1]
  AE_test_losses.append(train_loss)
  test_reconstruction_losses.append(train_recon)

  # ---------------- Display ----------------
  AE_test.eval()
  print('epoch [{}/{}], loss:{:.3f}, reconstruction:{:.3f}'.format(
      epoch + 1, total_epochs, AE_test_losses[-1], test_reconstruction_losses[-1]))
  with torch.no_grad():
    reconstructed_class_examples, _, _, _ = AE_test(class_examples)
    display_class_examples(class_examples)
    display_class_examples(reconstructed_class_examples)

    if counter % EPOCHS_BETWEEN_VAL == 0:
      # ---------------- Validate ----------------
      val_losses = AE_epoch(testloader, AE_test, AE_test_optimizer,
                            reconstruction_loss, train=False)
      val_AE_test_losses.append(val_losses[0])
      val_test_reconstruction_losses.append(val_losses[1])

      print('Validation loss:{:.3f}, reconstruction:{:.3f}'.format(
          val_AE_test_losses[-1], val_test_reconstruction_losses[-1]))
      reconstructed_val_class_examples, _, _, _ = AE_test(val_class_examples)
      display_class_examples(val_class_examples)
      display_class_examples(reconstructed_val_class_examples)

  AE_test.train()
  counter += 1

In [None]:
# Write the baseline AE's reconstructions of the test set to ./reconstructed_test.
with torch.no_grad():
  i = 0
  for images, classes in testloader:
    images = Variable(images).to(DEVICE)
    outputs = AE_test(images)
    reconstruction_test, encoded = outputs[0], outputs[1]
    save_images(reconstruction_test, "./reconstructed_test", i)
    i += 1

In [None]:
! python -m pytorch_fid ./original ./reconstructed_test

# Test cDCGAN

In [None]:
# Create the Discriminator
D_test = Discriminator(10).to(DEVICE)
print(D_test)
D_test_total_params = sum(p.numel() for p in D_test.parameters() if p.requires_grad)
print(f"Total Trainable Parameters: {D_test_total_params}")

In [None]:
# Create the Generator
G_test = Generator(10).to(DEVICE)
print(G_test)
G_test_total_params = sum(p.numel() for p in G_test.parameters() if p.requires_grad)
print(f"Total Trainable Parameters: {G_test_total_params}")

In [None]:
# custom weights initialization
def weights_init(net):
    classname = net.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(net.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(net.weight.data, 1.0, 0.02)
        nn.init.constant_(net.bias.data, 0)

In [None]:
# randomly initialize all weights to mean=0, stdev=0.2.
D_test.apply(weights_init)
G_test.apply(weights_init)

In [None]:
# Adam optimizer for generator
optimizerG_test = torch.optim.Adam(G_test.parameters(), lr=Adam_lr, betas=(Adam_beta1, 0.999))
# Adam optimizer for discriminator
optimizerD_test = torch.optim.Adam(D_test.parameters(), lr=Adam_lr, betas=(Adam_beta1, 0.999))

In [None]:
# labels for training images x for Discriminator training
labels_real = torch.ones((BATCH_SIZE, 1)).to(DEVICE)
# labels for generated images G(z) for Discriminator training
labels_fake = torch.zeros((BATCH_SIZE, 1)).to(DEVICE)
# Fix noise for testing generator and visualization
z_test = torch.randn(100, size_z).to(DEVICE)

In [None]:
# convert labels to onehot encoding
onehot = torch.zeros(10, 10).scatter_(1, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).view(10,1), 1)
# reshape labels to image size, with number of labels as channel
fill = torch.zeros([10, 10, img_size, img_size])
#channel corresponding to label will be set one and all other zeros
for i in range(10):
    fill[i, i, :, :] = 1
# create labels for testing generator
test_y = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]*10).type(torch.LongTensor)
# convert to one hot encoding
test_Gy = onehot[test_y].to(DEVICE)

In [None]:
# Fixed noise (one row per class) for the training and validation example grids.
class_examples_z = torch.randn(10, size_z).to(DEVICE)
val_class_examples_z = torch.randn(10, size_z).to(DEVICE)
print(class_examples_z.shape)

# One-hot label for each of the 10 classes: a (10, 10) identity matrix.
label_examples = torch.eye(10).to(DEVICE)

print(label_examples)

In [None]:
def _epoch_mean(values):
  """Return the arithmetic mean of ``values``, or ``nan`` when the list is empty."""
  return sum(values) / len(values) if values else float('nan')


def GAN_test_epoch(dataloader, D, D_optimizer, G, G_optimizer,
                   discrimination_loss, train=True):
  """Run one epoch of conditional-GAN training (or evaluation) over ``dataloader``.

  For every batch:
  * D scores the real images (target 1).
  * D scores freshly generated images conditioned on uniformly random class labels (target 0).
  * When ``train`` is True, D is updated on every batch.
  * G's adversarial loss is computed every ``Ksteps`` batches; G is updated then, if ``train`` is True.
  When ``train`` is False no optimizer steps are taken and autograd is disabled.

  Relies on notebook globals: ``DEVICE``, ``size_z``, ``Ksteps``, ``onehot``
  (class -> one-hot label row fed to G) and ``fill`` (class -> label planes fed to D).

  Args:
    dataloader: iterable of ``(images, classes)`` batches. Batch sizes may vary
      (e.g. a smaller final batch); all targets and noise are sized per batch.
    D, D_optimizer: discriminator, called as ``D(images, label_planes)``, and its optimizer.
    G, G_optimizer: generator, called as ``G(z, onehot_labels)``, and its optimizer.
    discrimination_loss: loss comparing D's outputs with 0/1 targets of the same
      shape (presumably ``nn.BCELoss`` — confirm against the defining cell).
    train: whether to back-propagate and step the optimizers.

  Returns:
    Tuple of epoch means ``(D_loss, D_x_loss, D_z_loss, D(x), G_loss, G_z_loss, D(G(z)))``,
    where ``D(x)`` / ``D(G(z))`` are the mean discriminator outputs on real / generated
    images. A component is ``nan`` if nothing was recorded (e.g. an empty dataloader).
  """
  epoch_D_losses = []
  epoch_D_x_losses = []
  epoch_D_z_losses = []
  epoch_Dx = []

  epoch_G_losses = []
  epoch_G_z_losses = []
  epoch_DGz = []

  # number of classes covered by the label lookup tables
  num_classes = onehot.size(0)

  # Only build autograd graphs when we are going to back-propagate.
  with torch.set_grad_enabled(bool(train)):
    for step, (images, classes) in enumerate(dataloader):
      images = images.to(DEVICE)
      # size everything from the actual batch: the last batch can be smaller than BATCH_SIZE
      batch_size = images.size(0)

      ############################
      # Forward Pass Through Generator
      ############################
      z = torch.randn(batch_size, size_z, device=DEVICE)
      # uniformly random class labels for the generator
      y_gen = torch.randint(0, num_classes, (batch_size,))
      G_y = onehot[y_gen].to(DEVICE)
      # matching label planes for D, one channel per class
      DG_y = fill[y_gen].to(DEVICE)
      fake_images = G(z, G_y)

      ############################
      # Discriminator loss on real images (target 1)
      ############################
      D_x = fill[classes].to(DEVICE)
      x_preds = D(images, D_x)
      D_x_loss = discrimination_loss(x_preds, torch.ones_like(x_preds))

      ############################
      # Discriminator loss on fake images (target 0)
      ############################
      # detach() so the discriminator loss does not back-propagate into G
      z_preds = D(fake_images.detach(), DG_y)
      D_z_loss = discrimination_loss(z_preds, torch.zeros_like(z_preds))

      D_loss = D_x_loss + D_z_loss

      # save values for plots
      epoch_D_losses.append(D_loss.item())
      epoch_D_x_losses.append(D_x_loss.item())
      epoch_D_z_losses.append(D_z_loss.item())
      epoch_Dx.append(x_preds.mean().item())

      if train:
        D.zero_grad()
        D_loss.backward()
        D_optimizer.step()

      ############################
      # Update G network every Ksteps batches
      ############################
      if step % Ksteps == 0:
        # re-score the same fakes with the (possibly just updated) D; G wants them judged real
        z_out = D(fake_images, DG_y)
        G_z_loss = discrimination_loss(z_out, torch.ones_like(z_out))
        G_loss = G_z_loss

        # save values for plots
        epoch_G_losses.append(G_loss.item())
        epoch_G_z_losses.append(G_z_loss.item())
        epoch_DGz.append(z_out.mean().item())

        if train:
          G.zero_grad()
          G_loss.backward()
          G_optimizer.step()

  ############################
  # Log
  ############################
  return (_epoch_mean(epoch_D_losses), _epoch_mean(epoch_D_x_losses),
          _epoch_mean(epoch_D_z_losses), _epoch_mean(epoch_Dx),
          _epoch_mean(epoch_G_losses), _epoch_mean(epoch_G_z_losses),
          _epoch_mean(epoch_DGz))

In [None]:
if PRE_TRAIN or ARCHITECHTURE in ('cDCGAN', 'VAE and cDCGAN'):
  # Per-epoch histories used for plotting. Their order matches the tuple returned by
  # GAN_test_epoch: (D loss, D real loss, D fake loss, D(x), G loss, G adversarial loss, D(G(z))).
  D_losses, D_x_losses, D_z_losses, Dx_values = [], [], [], []
  G_losses, G_z_losses, DGz_values = [], [], []

  val_D_losses, val_D_x_losses, val_D_z_losses, val_Dx_values = [], [], [], []
  val_G_losses, val_G_z_losses, val_DGz_values = [], [], []

  train_history = (D_losses, D_x_losses, D_z_losses, Dx_values,
                   G_losses, G_z_losses, DGz_values)
  val_history = (val_D_losses, val_D_x_losses, val_D_z_losses, val_Dx_values,
                 val_G_losses, val_G_z_losses, val_DGz_values)

  total_epochs = NUM_EPOCHS + NUM_PRETRAIN_EPOCHS
  for epoch in range(total_epochs):
    ############################
    # Train
    ############################
    losses = GAN_test_epoch(dataloader, D_test, optimizerD_test, G_test, optimizerG_test,
                            discrimination_loss,
                            train=True)
    for history, value in zip(train_history, losses):
      history.append(value)

    ############################
    # Display
    ############################
    # eval mode for the models being trained here, so BatchNorm/Dropout use inference behaviour
    G_test.eval()
    D_test.eval()
    print(f"Epoch {epoch+1}/{total_epochs} Discriminator Loss {D_losses[-1]:.3f} "
        + f"Generator Loss {G_losses[-1]:.3f} "
        + f"D(x) {Dx_values[-1]:.3f} D(G(z)) {DGz_values[-1]:.3f}")
    with torch.no_grad():
      fake_class_examples = G_test(class_examples_z, label_examples)
      display_class_examples(fake_class_examples)

      if epoch % EPOCHS_BETWEEN_VAL == 0:
        ############################
        # Validate
        ############################
        val_losses = GAN_test_epoch(testloader, D_test, optimizerD_test, G_test, optimizerG_test,
                                    discrimination_loss,
                                    train=False)
        # record the *validation* results (not the training ones)
        for history, value in zip(val_history, val_losses):
          history.append(value)

        ############################
        # Display
        ############################
        print(f"Validation Discriminator Loss {val_D_losses[-1]:.3f} "
            + f"Generator Loss {val_G_losses[-1]:.3f} "
            + f"D(x) {val_Dx_values[-1]:.3f} D(G(z)) {val_DGz_values[-1]:.3f}")
        fake_val_class_examples = G_test(val_class_examples_z, label_examples)
        display_class_examples(fake_val_class_examples)

    # back to training mode for the next epoch
    D_test.train()
    G_test.train()

In [None]:
# Generate one fake image per real test image (same batch sizes as testloader) for FID scoring.
# eval(): use BatchNorm running statistics instead of batch statistics, and do not update them while sampling.
G_test.eval()
with torch.no_grad():
  for batch_idx, (images, _classes) in enumerate(testloader):
    # only the batch size of the real images is needed; they are not moved to the device
    batch_size = images.size(0)

    # fresh latent vectors (kept separate from the fixed global z_test)
    z_batch = torch.randn(batch_size, size_z, device=DEVICE)
    # uniformly random class labels, converted to one-hot for the generator
    y_gen = torch.randint(0, onehot.size(0), (batch_size,))
    G_y = onehot[y_gen].to(DEVICE)
    fake_images_test = G_test(z_batch, G_y)
    save_images(fake_images_test, "./generated_test", batch_idx)

In [None]:
! python -m pytorch_fid ./original ./generated_test