In [1]:
# Import
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy
# Seed PyTorch's global RNG so weight initialisation and noise sampling repeat across runs
torch.manual_seed(0)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7ffe151a94d0>

In [2]:
from torch.optim.lr_scheduler import ExponentialLR
import prototorch as pt
from prototorch.models import GMLVQ

In [3]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Display the leading images of a tensor as a grid with five images per row.
    Parameters:
        image_tensor: tensor of images; it is reshaped to (-1, *size) before plotting
        num_images: how many images (taken from the front) to show
        size: per-image shape used for the reshape
    '''
    # Work on a graph-free CPU copy so plotting never touches autograd state
    images = image_tensor.detach().cpu()
    images = images.view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    # Move the first axis to the end (channels-last) before handing the grid to imshow
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [4]:
def get_generator_block(input_dim, output_dim):
    '''
    Build one hidden block of the generator.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        an nn.Sequential holding, in order, a linear layer, a 1-D batch
          normalization over output_dim features, and an in-place ReLU
    '''
    # https://pytorch.org/docs/stable/nn.html for reference
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [5]:
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_dim: the dimension of the generated (flattened) images, a scalar
          (MNIST images are 28 x 28 = 784, hence the default)
        hidden_dim: the base width of the hidden layers, a scalar
    '''
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super().__init__()
        # Layer widths: z_dim -> h -> 2h -> 4h -> 8h, then a final projection to im_dim
        widths = [z_dim] + [hidden_dim * factor for factor in (1, 2, 4, 8)]
        hidden_blocks = [
            get_generator_block(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        self.gen = nn.Sequential(
            *hidden_blocks,
            nn.Linear(widths[-1], im_dim),
            nn.Sigmoid(),  # squashes outputs into (0, 1)
        )

    def forward(self, noise):
        '''
        Run the generator on a batch of noise vectors.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        Returns:
            the output of the generator network for that noise
        '''
        generated = self.gen(noise)
        return generated

In [6]:
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Sample a batch of noise vectors from the standard normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        z_dim: the dimension of the noise vector, a scalar
        device: the device on which the tensor is created
    Returns:
        a tensor of shape (n_samples, z_dim)
    '''
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)



In [7]:
def get_discriminator_block(input_dim, output_dim):
    '''
    Discriminator Block
    Build one hidden block of the discriminator.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        an nn.Sequential holding a linear layer followed by an
          nn.LeakyReLU activation with negative slope of 0.2
    '''
    linear = nn.Linear(input_dim, output_dim)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(linear, activation)

In [33]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Wraps a prototorch GMLVQ model followed by a LeakyReLU activation.
    Values:
        train_data: data handed to the SMCI prototype initializer
        im_dim: the dimension of the images, fitted for the dataset used, a scalar
            (MNIST images are 28x28 = 784 so that is your default)
        hidden_dim: the inner dimension, a scalar
            (currently only referenced by the disabled dense blocks below)
        num_classes: number of classes in the initial prototype distribution
        per_class: number of prototypes per class in the initial prototype distribution
    '''
    def __init__(self, train_data, im_dim=784, hidden_dim=128, num_classes=10, per_class=3):
        super(Discriminator, self).__init__()

        # Hyperparameters for the gmlvq layer.
        # The distribution is driven by the constructor arguments; previously it
        # was hard-coded to 10 / 3, which silently ignored num_classes / per_class.
        self.hparams = dict(
            input_dim=im_dim,
            latent_dim=im_dim,
            distribution={
                "num_classes": num_classes,
                "per_class": per_class,
            },
            proto_lr=0.01,
            bb_lr=0.01,
        )
        self.example_input_array = torch.zeros(im_dim, im_dim)

        self.disc = nn.Sequential(
            # Disabled dense front-end, kept for reference:
            # get_discriminator_block(im_dim, hidden_dim * 4),
            # get_discriminator_block(hidden_dim * 4, im_dim),
            GMLVQ(
                self.hparams,
                optimizer=torch.optim.Adam,
                prototypes_initializer=pt.initializers.SMCI(train_data),
                lr_scheduler=ExponentialLR,
                lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
            ),
            nn.LeakyReLU(negative_slope=0.2)
        )
        gmlvq = self.disc[0]
        print("generated prototypes: ", gmlvq.prototypes.shape)
        print("generated prototype labels:", gmlvq.prototype_labels)

        # Append three additional, uniformly initialised prototypes.
        # NOTE(review): presumably these are meant to represent an extra "fake"
        # class; confirm which label prototorch assigns to a distribution given
        # as num_classes=1, and that the (1, im_dim) component shape matches the
        # prototypes produced by SMCI above.
        gmlvq.add_prototypes(
            {
                "num_classes": 1,
                "per_class": 3
            },
            pt.initializers.UCI((1, im_dim)),
        )
        print("generated prototypes: ", gmlvq.prototypes.shape)
        print("generated prototype labels:", gmlvq.prototype_labels)

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: Given an image tensor,
        returns the GMLVQ output passed through the LeakyReLU activation.
        Parameters:
            image: a flattened image tensor with dimension (im_dim)
        NOTE(review): the loss helpers in this notebook compare this output
        against targets with two columns; confirm the output shape matches.
        '''
        return self.disc(image)

In [34]:
def get_disc_loss(gen, disc, criterion, real, labels, num_images, z_dim, device):
    '''
    Return the loss of the discriminator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, whose predictions are compared against
               two-column targets (see below)
        criterion: the loss function, which should be used to compare
               the discriminator's predictions to the ground truth reality of the images
        real: a batch of real images
        labels: class labels for the current batch of real images (currently unused)
        num_images: the number of images the generator should produce,
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        disc_loss: a torch scalar loss value for the current batch
    Target encoding (shape (num_images, 2)):
        generated images -> [1, 0], real images -> [0, 1]
    '''
    noise_vectors = get_noise(num_images, z_dim, device)
    # Detach so the discriminator update does not backpropagate into the generator
    gen_images = gen(noise_vectors).detach()
    gen_predictions = disc(gen_images)

    # Build the target columns on the same device as the predictions; creating
    # them on the default device breaks the criterion whenever the models are not on CPU.
    target_device = gen_predictions.device
    ones = torch.ones(num_images, 1, dtype=torch.float32, device=target_device)
    zeros = torch.zeros(num_images, 1, dtype=torch.float32, device=target_device)

    # Loss on generated images: target [1, 0]
    gen_ground_truth = torch.cat([ones, zeros], 1)
    gen_loss = criterion(gen_predictions, gen_ground_truth)

    # Loss on real images: target [0, 1]
    real_predictions = disc(real)
    real_ground_truth = torch.cat([zeros, ones], 1)
    real_loss = criterion(real_predictions, real_ground_truth)

    # Average the two halves so neither dominates the update
    disc_loss = torch.mean(torch.stack((gen_loss, real_loss)))
    return disc_loss

In [35]:
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, whose predictions are compared against
               two-column targets (see below)
        criterion: the loss function, which should be used to compare
               the discriminator's predictions to the ground truth reality of the images
        num_images: the number of images the generator should produce,
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        gen_loss: a torch scalar loss value for the current batch
    Target encoding (shape (num_images, 2)):
        every generated image is scored against [0, 1], the same target
        get_disc_loss uses for real images
    '''
    noise_vectors = get_noise(num_images, z_dim, device)
    # No detach here: gradients must flow through the discriminator into the generator
    gen_images = gen(noise_vectors)
    predictions = disc(gen_images)

    # Build the target on the same device as the predictions; creating it on the
    # default device breaks the criterion whenever the models are not on CPU.
    target_device = predictions.device
    ones = torch.ones(num_images, 1, dtype=torch.float32, device=target_device)
    zeros = torch.zeros(num_images, 1, dtype=torch.float32, device=target_device)

    gen_loss = criterion(predictions, torch.cat([zeros, ones], 1))
    return gen_loss

In [36]:
# Training hyperparameters
n_epochs = 10        # number of passes over the dataloader
batch_size = 128     # samples per batch
z_dim = 64           # dimension of the generator's noise vector
lr = 1e-5            # learning rate passed to both Adam optimizers
display_step = 500   # steps between loss printouts / image previews
criterion = nn.BCEWithLogitsLoss()  # loss used by both loss helpers

In [37]:
# Per-sample preprocessing applied by the dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # (disabled) transforms.Normalize((0.5,), (0.5,)),
    # Keep the first axis of each image tensor and flatten the remaining axes.
    # This runs on every sample individually, not on a whole batch.
    transforms.Lambda(lambda img: img.view(img.shape[0], -1)),
])

# MNIST is read from the current directory; it must already be present (no download)
train_ds = MNIST(root='.', download=False, transform=transform)

# Shuffled mini-batches of the dataset
dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

In [16]:
# Adding samples for the fake class
# Append uniformly random "images" labelled with an extra class index to the dataset.
n_fake = 30
fake_class = 10
# Guard against re-running the cell: without it every execution appends another n_fake samples.
if not bool((train_ds.targets == fake_class).any()):
    # Draw the pixels in the dataset's own dtype. torch.randint defaults to int64,
    # and concatenating via numpy.r_ promoted the whole image tensor to int64,
    # which no longer matches the 8-bit image data the dataset hands to PIL.
    rsample = torch.randint(0, 256, (n_fake, 28, 28), dtype=train_ds.data.dtype)
    train_ds.data = torch.cat([train_ds.data, rsample])
    fake_targets = torch.full((n_fake,), fake_class, dtype=train_ds.targets.dtype)
    train_ds.targets = torch.cat([train_ds.targets, fake_targets])

In [38]:
### Setting to CPU ###
# Device string used below when creating the models and when sampling noise
device = 'cpu'

In [None]:
# Build both networks on `device` (generator first, then discriminator),
# then give each its own Adam optimizer with the shared learning rate
gen = Generator(z_dim=z_dim).to(device)
disc = Discriminator(train_data=train_ds).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

In [None]:
# Bookkeeping for the GAN training loop
cur_step = 0  # global step counter across epochs
# Running loss averages, reset after every display step
mean_generator_loss = mean_discriminator_loss = 0
test_generator = True  # whether the loop snapshots the generator weights each step
gen_loss = False
error = False

In [None]:
for epoch in range(n_epochs):

    # Dataloader returns the batches
    for real, ylabels in tqdm(dataloader):
        cur_batch_size = len(real)

        # The dataset transform already flattens each image, so only the device
        # move is needed here (it was previously skipped entirely).
        real = real.to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, ylabels, cur_batch_size, z_dim, device)

        # Backpropagate. The graph is not reused afterwards (the generator step
        # runs a fresh forward pass), so it is not retained.
        disc_loss.backward()

        # Update optimizer
        disc_opt.step()

        # For testing purposes, to keep track of the generator weights
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # Preview only: no autograd graph is needed for these samples
            with torch.no_grad():
                fake = gen(fake_noise)
            show_tensor_images(fake, num_images=64)
            show_tensor_images(real, num_images=64)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1