In [29]:
import torch
import math
import random
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [38]:
def show_tensor_images(image_tensor, num_images = 25, size = (1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a grid with 5 images per row.

    Pixel values are shifted from [-1, 1] (the generator's Tanh range) to [0, 1]
    before plotting.

    NOTE(review): ``size`` is accepted but never used here; the batch is assumed to
    already be laid out as (N, C, H, W), which is what ``make_grid`` expects —
    confirm at call sites.
    """
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    # make_grid returns (C, H, W); matplotlib wants (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [39]:
def combineVectors(v1, v2):
    """Concatenate two tensors along dimension 1 (the feature/channel axis)."""
    return torch.cat([v1, v2], dim=1)

In [40]:
def oneHotEncode(numClasses, labels):
    """Return the one-hot encoding of integer ``labels`` with ``numClasses`` columns."""
    return torch.nn.functional.one_hot(labels, num_classes=numClasses)

In [27]:
def getInputDimensions(z_dim, shape, numClasses):
    """Compute input sizes for a conditional GAN.

    Returns:
        (generator_input_dim, discriminator_input_dim): the noise length plus one
        slot per class, and the image channel count (``shape[0]``) plus one
        one-hot channel per class.
    """
    image_channels = shape[0]
    return z_dim + numClasses, image_channels + numClasses

In [41]:
def generate_noise(n_examples, z_dim, device = 'cpu'):
    """Draw an (n_examples, z_dim) tensor of standard-normal noise on ``device``."""
    return torch.randn((n_examples, z_dim), device=device)

In [29]:
def weights_init(m, bn_weight_mean=1.0):
    """DCGAN-style parameter initialisation, intended for ``module.apply(weights_init)``.

    - ``Conv2d`` / ``ConvTranspose2d``: weight ~ N(0, 0.02) (bias left untouched).
    - ``BatchNorm2d``: weight ~ N(bn_weight_mean, 0.02), bias = 0.
    - Any other module is left unchanged.

    Args:
        m: the module to initialise.
        bn_weight_mean: mean of the BatchNorm scale (gamma). Defaults to 1.0 so that
            BatchNorm starts close to an identity scaling. The previous behaviour drew
            gamma around 0.0, which shrank and randomly sign-flipped every normalised
            activation at start-up. Pass 0.0 to reproduce that old behaviour.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    # affine=False BatchNorm has no weight/bias parameters (they are None).
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        torch.nn.init.normal_(m.weight, bn_weight_mean, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [30]:
class Generator(nn.Module):
    """DCGAN-style generator that maps a noise vector to an image.

    The flat input of length ``z_dim`` is reshaped to (N, z_dim, 1, 1) and passed
    through four transposed-convolution blocks. With the default kernel/stride
    choices the spatial size grows 1 -> 3 -> 6 -> 13 -> 28. The last block ends in
    Tanh, so pixel values lie in [-1, 1].

    Args:
        z_dim: length of each input vector (noise, plus class slots when conditional).
        img_channels: number of output image channels.
        hidden_dims: base width; the hidden blocks use 4x, 2x and 1x this value.
    """

    def __init__(self, z_dim = 10, img_channels = 1, hidden_dims = 64):
        super().__init__()
        self.z_dim = z_dim
        blocks = [
            self.make_gen_block(z_dim, hidden_dims * 4),
            self.make_gen_block(hidden_dims * 4, hidden_dims * 2, kernel_size = 4, stride = 1),
            self.make_gen_block(hidden_dims * 2, hidden_dims),
            self.make_gen_block(hidden_dims, img_channels, kernel_size = 4, final_layer = True),
        ]
        self.gen = nn.Sequential(*blocks)

    def make_gen_block(self, input_channels, output_channels, kernel_size = 3, stride = 2, final_layer = False):
        """Build one upsampling block.

        The block is ConvTranspose2d followed by either BatchNorm2d + ReLU (hidden
        blocks) or Tanh (the final block).
        """
        layers = [nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            layers.append(nn.Tanh())
        else:
            layers += [nn.BatchNorm2d(output_channels), nn.ReLU()]
        return nn.Sequential(*layers)

    def unsqueeze_noise(self, noise):
        """Reshape (N, z_dim) noise into an (N, z_dim, 1, 1) 1x1 feature map."""
        batch_size = len(noise)
        return noise.view(batch_size, self.z_dim, 1, 1)

    def forward(self, noise):
        """Generate a batch of images from a batch of input vectors."""
        return self.gen(self.unsqueeze_noise(noise))

In [31]:
class Critic(nn.Module):
    """Convolutional critic/discriminator producing one unbounded score per image.

    There are three stride-2 convolution blocks. For a 28x28 input the spatial size
    shrinks 28 -> 13 -> 5 -> 1, and the result is flattened to (N, 1).

    Args:
        img_channels: input channels (image channels, plus one-hot label channels
            when used conditionally).
        hidden_dims: width of the first hidden block; the second uses twice this.

    NOTE(review): the WGAN-GP paper advises against BatchNorm in the critic, because
    the gradient penalty is defined per sample. Consider removing it or switching to a
    per-sample normaliser when training with the Wasserstein loss.
    """

    def __init__(self, img_channels = 1, hidden_dims = 64):
        super().__init__()
        blocks = [
            self.make_crit_block(img_channels, hidden_dims),
            self.make_crit_block(hidden_dims, hidden_dims * 2),
            self.make_crit_block(hidden_dims * 2, 1, final_layer = True),
        ]
        self.crit = nn.Sequential(*blocks)

    def make_crit_block(self, input_channels, output_channels, kernel_size = 4, stride = 2, final_layer = False):
        """Build one downsampling block.

        The block is Conv2d followed by BatchNorm2d + LeakyReLU(0.2) for hidden
        blocks; the final block is the bare Conv2d.
        """
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers += [nn.BatchNorm2d(output_channels), nn.LeakyReLU(0.2, inplace = True)]
        return nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images, returning a (N, -1) flattened tensor."""
        scores = self.crit(img)
        return scores.view(len(scores), -1)

In [32]:
def get_gradient(crit, real, fake, epsilon):
    """Gradient of the critic's output w.r.t. random real/fake interpolations (WGAN-GP).

    Args:
        crit: callable critic mapping an image batch to scores.
        real: batch of real inputs.
        fake: batch of generated inputs, same shape as ``real``.
        epsilon: mixing weights broadcastable against ``real`` (e.g. shape
            (N, 1, 1, 1)). ``epsilon`` weights ``real``; ``1 - epsilon`` weights ``fake``.

    Returns:
        Tensor shaped like the interpolated batch holding d(crit)/d(input). The graph
        is kept (``create_graph=True``) so that a penalty built from it can itself be
        backpropagated.
    """
    interpolated_img = real * epsilon + fake * (1 - epsilon)
    if not interpolated_img.requires_grad:
        # Nothing upstream tracks gradients (e.g. epsilon created without
        # requires_grad and detached inputs). The mix is then a leaf tensor, so it can
        # be flagged directly; otherwise autograd.grad would raise.
        interpolated_img.requires_grad_(True)
    pred = crit(interpolated_img)
    grad = torch.autograd.grad(
        inputs=interpolated_img,
        outputs=pred,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
        retain_graph=True,
    )[0]
    return grad

In [33]:
def gradient_penalty(gradient):
    """WGAN-GP penalty: batch mean of (||g_i||_2 - 1)^2.

    Each sample's gradient is flattened before its L2 norm is taken.
    """
    per_sample_norm = gradient.view(len(gradient), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

In [34]:
def wasserstein_loss_gen(fake_pred):
    """Generator Wasserstein loss: the negated mean critic score on fake samples."""
    mean_fake_score = fake_pred.mean()
    return -mean_fake_score

In [35]:
def wasserstein_loss_crit(fake_pred, real_pred, penalty, lambda_ = 0.1):
    """Critic WGAN-GP loss: mean(fake) - mean(real) + lambda_ * penalty.

    NOTE(review): the WGAN-GP paper uses lambda = 10. The 0.1 default here is much
    weaker; the training loop in this notebook passes ``lambda_`` explicitly.
    """
    score_gap = fake_pred.mean() - real_pred.mean()
    weighted_penalty = (lambda_ * penalty).mean()
    return score_gap + weighted_penalty

In [1]:
def _resolve_labels(labels, count, numClasses, device):
    """Turn the user-supplied ``labels`` argument into ``count`` int64 labels on ``device``."""
    if labels is None or len(labels) == 0:
        # Random classes; uses Python's `random` as before so seeded runs are unchanged.
        labels = [random.randint(0, numClasses - 1) for _ in range(count)]
    elif len(labels) == 1:
        # A single label means "generate every image of this class".
        labels = [int(labels[0])] * count
    label_tensor = torch.as_tensor(labels, device=device).to(torch.int64)
    if len(label_tensor) != count:
        raise ValueError(
            f"expected 1 or {count} labels, got {len(label_tensor)}"
        )
    return label_tensor


def generate_images(generator, n_examples, z_dim, conditional, numClasses = 0, num_images = 25, size = (1, 28, 28), labels = None, device = 'cpu'):
    """Sample ``n_examples`` images from ``generator`` and display up to ``num_images`` of them.

    Args:
        generator: model mapping input vectors (noise, plus a one-hot label when
            conditional) to images.
        n_examples: number of samples to generate.
        z_dim: noise dimensionality.
        conditional: if True, a one-hot class vector is appended to each noise vector.
        numClasses: number of classes (conditional only).
        num_images: how many of the generated images to display.
        size: forwarded to ``show_tensor_images``.
        labels: conditional only.
            ``None``/empty: random classes in [0, numClasses).
            One label: that class is used for every sample.
            Otherwise: one label per sample (list or tensor, length ``n_examples``).
        device: device for the noise and labels.

    Raises:
        ValueError: if ``labels`` has a length other than 0, 1 or ``n_examples``.
    """
    if conditional:
        # Previously the mutable default `labels=[]` was appended to and reused across
        # calls. Labels are now built into a fresh tensor per call, one per noise vector.
        label_tensor = _resolve_labels(labels, n_examples, numClasses, device)
        one_hot = oneHotEncode(numClasses, label_tensor)
        pred_noise = generate_noise(n_examples, z_dim, device)
        gen_input = combineVectors(pred_noise, one_hot.to(pred_noise.dtype))
    else:
        gen_input = generate_noise(n_examples, z_dim, device)
    # Sampling only; no gradients are needed for display.
    with torch.no_grad():
        pred_images = generator(gen_input)
    show_tensor_images(pred_images, num_images, size)

In [2]:
def _build_models(conditional, z_dim, shape, numClasses, img_channels, learning_rate, beta_1, beta_2, device):
    """Create a freshly initialised generator/critic pair together with their Adam optimizers."""
    if conditional:
        generator_input_dim, discriminator_im_chan = getInputDimensions(z_dim, shape, numClasses)
        gen = Generator(generator_input_dim, img_channels).to(device)
        crit = Critic(discriminator_im_chan).to(device)
    else:
        gen = Generator(z_dim, img_channels).to(device)
        crit = Critic(img_channels).to(device)
    gen_optimizer = torch.optim.Adam(gen.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
    crit_optimizer = torch.optim.Adam(crit.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
    gen = gen.apply(weights_init)
    crit = crit.apply(weights_init)
    return gen, crit, gen_optimizer, crit_optimizer


def _plot_losses(generator_losses, critic_losses, step_bins=20):
    """Plot generator and critic losses averaged over consecutive bins of ``step_bins`` steps."""
    num_examples = (len(generator_losses) // step_bins) * step_bins
    plt.plot(
        range(num_examples // step_bins),
        torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Generator Loss"
    )
    plt.plot(
        range(num_examples // step_bins),
        torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Critic Loss"
    )
    plt.legend()
    plt.show()


def GAN(training_set, trained, generator, critic, generator_optimizer, critic_optimizer, epochs, display_step, img_channels, conditional, crit_repeats, learning_rate, beta_1, beta_2, lambda_, z_dim, shape, numClasses, dataloader, loss, device):
    """Train (or continue training) a generator/critic pair.

    Args:
        training_set: unused; kept for call compatibility.
        trained: if falsy, build and initialise new models and optimizers. Otherwise
            continue with ``generator``, ``critic``, ``generator_optimizer`` and
            ``critic_optimizer``.
        epochs: passes over ``dataloader``.
        display_step: print/plot progress every this many steps.
        img_channels: channels of the generated images.
        conditional: if True, one-hot labels are appended to the noise (generator) and
            as constant feature maps to the images (critic).
        crit_repeats: critic updates per generator update (``loss='W'`` only).
        learning_rate, beta_1, beta_2: Adam hyper-parameters for newly built optimizers.
        lambda_: gradient-penalty weight (``loss='W'`` only).
        z_dim: noise dimensionality.
        shape: image shape; ``shape[0]`` is channels, ``shape[1:]`` the spatial size.
        numClasses: number of classes (conditional only).
        dataloader: yields ``(images, labels)`` batches.
        loss: ``'W'`` (WGAN-GP) or ``'BCE'`` (BCE-with-logits).
        device: torch device for models and data.

    Returns:
        (gen, crit, gen_optimizer, crit_optimizer) after training.

    Raises:
        ValueError: if ``loss`` is neither ``'W'`` nor ``'BCE'``.
    """
    # An unknown loss used to fall through both branches and silently train nothing.
    if loss not in ('W', 'BCE'):
        raise ValueError(f"loss must be 'W' or 'BCE', got {loss!r}")

    if not trained:
        gen, crit, gen_optimizer, crit_optimizer = _build_models(
            conditional, z_dim, shape, numClasses, img_channels,
            learning_rate, beta_1, beta_2, device,
        )
    else:
        gen, crit = generator, critic
        gen_optimizer, crit_optimizer = generator_optimizer, critic_optimizer

    criterion = nn.BCEWithLogitsLoss() if loss == 'BCE' else None

    # Append class information to generator / critic inputs when conditional.
    # Both read `one_hot` / `image_one_hot_labels` of the current batch (set below).
    def gen_input(noise):
        return combineVectors(noise, one_hot) if conditional else noise

    def crit_input(imgs):
        return combineVectors(imgs, image_one_hot_labels) if conditional else imgs

    cur_step = 0
    generator_losses = []
    critic_losses = []
    for epoch in range(epochs):
        for real, labels in tqdm(dataloader):
            cur_batch_size = len(real)
            real = real.to(device)
            if conditional:
                one_hot = oneHotEncode(numClasses, labels.to(device))
                # Broadcast each one-hot vector to constant (H, W) feature maps.
                image_one_hot_labels = one_hot[:, :, None, None].repeat(1, 1, shape[1], shape[2])

            if loss == 'W':
                mean_iteration_critic_loss = 0
                for _ in range(crit_repeats):
                    crit_optimizer.zero_grad()
                    fake_noise = generate_noise(cur_batch_size, z_dim, device = device)
                    epsilon = torch.rand(len(real), 1, 1, 1, device = device, requires_grad = True)
                    fake_imgs = gen(gen_input(fake_noise))
                    # Detach so the critic update does not backpropagate into the generator.
                    fake_crit_in = crit_input(fake_imgs.detach())
                    real_crit_in = crit_input(real)
                    fake_pred = crit(fake_crit_in)
                    real_pred = crit(real_crit_in)
                    grad = get_gradient(crit, real_crit_in, fake_crit_in, epsilon)
                    penalty = gradient_penalty(grad)
                    crit_loss = wasserstein_loss_crit(fake_pred, real_pred, penalty, lambda_)
                    mean_iteration_critic_loss += crit_loss.item() / crit_repeats
                    crit_loss.backward()
                    crit_optimizer.step()
                # Record one critic loss per step so it lines up with generator_losses
                # (it used to be appended, as a partial mean, inside the loop above).
                critic_losses += [mean_iteration_critic_loss]
                cur_step += 1

                gen_optimizer.zero_grad()
                fake_noise_1 = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs_1 = gen(gen_input(fake_noise_1))
                fake_pred_1 = crit(crit_input(fake_imgs_1))
                gen_loss = wasserstein_loss_gen(fake_pred_1)
                gen_loss.backward()
                gen_optimizer.step()
                generator_losses += [gen_loss.item()]

            else:  # loss == 'BCE'
                crit_optimizer.zero_grad()
                fake_noise = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs = gen(gen_input(fake_noise))
                fake_pred = crit(crit_input(fake_imgs.detach()))
                real_pred = crit(crit_input(real))
                crit_loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
                crit_loss_real = criterion(real_pred, torch.ones_like(real_pred))
                crit_loss = (crit_loss_fake + crit_loss_real) / 2
                critic_losses += [crit_loss.item()]
                # The generator graph behind fake_imgs was not traversed (detached above),
                # so it is still intact for the generator update below.
                crit_loss.backward()
                crit_optimizer.step()
                cur_step += 1

                gen_optimizer.zero_grad()
                fake_pred_1 = crit(crit_input(fake_imgs))
                gen_loss = criterion(fake_pred_1, torch.ones_like(fake_pred_1))
                gen_loss.backward()
                gen_optimizer.step()
                generator_losses += [gen_loss.item()]

            if cur_step % display_step == 0 and cur_step > 0:
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                crit_mean = sum(critic_losses[-display_step:]) / display_step
                print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
                show_tensor_images(fake_imgs, size = shape)
                show_tensor_images(real, size = shape)
                _plot_losses(generator_losses, critic_losses)
    return gen, crit, gen_optimizer, crit_optimizer