In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [36]:
def show_tensor_images(image_tensor, num_images = 25, size = (1, 28, 28)):
    """Display the first images of a batch as a grid, five images per row.

    Args:
        image_tensor: Batch of images. Values are assumed to lie in [-1, 1]
            (they are shifted to [0, 1] before plotting). May be shaped
            ``(N, C, H, W)`` or flattened to ``(N, C*H*W)``.
        num_images: Maximum number of images taken from the start of the batch.
        size: ``(C, H, W)`` of one image. Only used to un-flatten a 2-D
            ``image_tensor``; ignored when the batch already has image layout.
    """
    # Shift [-1, 1] to [0, 1], the range imshow expects for float data.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    if image_unflat.dim() == 2:
        # A flattened batch would otherwise be tiled as one (N x C*H*W) image;
        # restore the per-image layout so each sample becomes its own tile.
        image_unflat = image_unflat.reshape(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    # make_grid yields (C, H, W); matplotlib wants channels last.
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [37]:
def generate_noise(n_examples, z_dim, device = 'cpu'):
    """Draw a batch of standard-normal noise vectors.

    Args:
        n_examples: Number of noise vectors (rows) to sample.
        z_dim: Length of each noise vector (columns).
        device: Device on which the tensor is created.

    Returns:
        Tensor of shape ``(n_examples, z_dim)`` sampled with ``torch.randn``.
    """
    return torch.randn((n_examples, z_dim), device = device)

In [71]:
def weights_init(m):
    """Re-initialise one module in place; meant for ``nn.Module.apply``.

    * ``Conv2d`` / ``ConvTranspose2d``: weight drawn from N(0, 0.02); the bias
      is left as it was.
    * ``BatchNorm2d``: weight drawn from N(0, 0.02), bias set to 0.

    Any other module type is left untouched.

    NOTE(review): the BatchNorm scale is centred at 0.0 here, while the
    reference DCGAN implementation centres it at 1.0 - confirm this is intended.
    """
    conv_like = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_like):
        nn.init.normal_(m.weight, mean = 0.0, std = 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean = 0.0, std = 0.02)
        nn.init.constant_(m.bias, val = 0)

In [33]:
class Generator(nn.Module):
    """Transposed-convolution generator that turns noise vectors into images.

    The network is four stacked blocks. With the default kernel sizes and
    strides (and no padding) a 1x1 input grows 1 -> 3 -> 6 -> 13 -> 28 pixels
    per side, and the final ``Tanh`` bounds the output to [-1, 1].

    Args:
        z_dim: Length of the input noise vector.
        img_channels: Number of channels in the generated image.
        hidden_dims: Base channel width; inner blocks use 4x, 2x and 1x of it.
    """

    def __init__(self, z_dim = 10, img_channels = 1, hidden_dims = 64):
        super().__init__()
        self.z_dim = z_dim
        wide = hidden_dims * 4
        medium = hidden_dims * 2
        stages = [
            self.make_gen_block(z_dim, wide),
            self.make_gen_block(wide, medium, kernel_size = 4, stride = 1),
            self.make_gen_block(medium, hidden_dims),
            self.make_gen_block(hidden_dims, img_channels, kernel_size = 4, final_layer = True),
        ]
        self.gen = nn.Sequential(*stages)

    def make_gen_block(self, input_channels, output_channels, kernel_size = 3, stride = 2, final_layer = False):
        """Build one generator stage.

        Hidden stages are ConvTranspose2d -> BatchNorm2d -> ReLU; the final
        stage is ConvTranspose2d -> Tanh.
        """
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Tanh())
        return nn.Sequential(upsample, nn.BatchNorm2d(output_channels), nn.ReLU())

    def unsqueeze_noise(self, noise):
        """Reshape ``(N, z_dim)`` noise to ``(N, z_dim, 1, 1)`` for the conv stack."""
        batch = len(noise)
        return noise.view(batch, self.z_dim, 1, 1)

    def forward(self, noise):
        """Generate one image per row of ``noise``."""
        return self.gen(self.unsqueeze_noise(noise))

In [90]:
class Critic(nn.Module):
    """Convolutional critic that maps an image batch to one row of scores each.

    Three stride-2 convolution stages; the last one has a single output channel
    and no normalisation or activation, so the scores are unbounded.

    Args:
        img_channels: Number of channels of the input images.
        hidden_dims: Base channel width; the hidden stages use 1x and 2x of it.
    """

    def __init__(self, img_channels = 1, hidden_dims = 64):
        super().__init__()
        doubled = hidden_dims * 2
        stages = [
            self.make_crit_block(img_channels, hidden_dims),
            self.make_crit_block(hidden_dims, doubled),
            self.make_crit_block(doubled, 1, final_layer = True),
        ]
        self.crit = nn.Sequential(*stages)

    def make_crit_block(self, input_channels, output_channels, kernel_size = 4, stride = 2, final_layer = False):
        """Build one critic stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is a bare Conv2d.
        """
        downsample = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(downsample)
        return nn.Sequential(
            downsample,
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2, inplace = True),
        )

    def forward(self, img):
        """Score ``img`` and flatten everything after the batch dimension."""
        scores = self.crit(img)
        return scores.view(len(scores), -1)

In [83]:
def get_gradient(crit, real, fake, epsilon):
    """Gradient of the critic's scores with respect to interpolated images.

    Args:
        crit: Callable that scores a batch of images.
        real: Batch of real images.
        fake: Batch of generated images, same shape as ``real``.
        epsilon: Mixing weights broadcastable against the images; the
            interpolation is ``epsilon * real + (1 - epsilon) * fake``.

    Returns:
        Tensor shaped like the interpolated batch. It is computed with
        ``create_graph=True`` so a loss built on it can be back-propagated.
    """
    mixed = epsilon * real + (1 - epsilon) * fake
    scores = crit(mixed)
    (grad_wrt_mixed,) = torch.autograd.grad(
        outputs = scores,
        inputs = mixed,
        grad_outputs = torch.ones_like(scores),
        create_graph = True,
        retain_graph = True,
    )
    return grad_wrt_mixed

In [80]:
def gradient_penalty(gradient):
    """Mean squared deviation of per-sample gradient L2 norms from 1.

    Args:
        gradient: Batch of gradients; everything after the first dimension is
            flattened into one vector per sample.

    Returns:
        Scalar tensor ``mean((||g_i||_2 - 1) ** 2)`` over the batch.
    """
    per_sample = gradient.view(len(gradient), -1)
    l2_norms = per_sample.norm(2, dim = 1)
    deviation = l2_norms - 1
    return torch.mean(deviation ** 2)

In [66]:
def wasserstein_loss_gen(fake_pred):
    """Generator loss: the negated mean critic score of the generated images.

    Args:
        fake_pred: Critic scores for generated images.

    Returns:
        Scalar tensor ``-mean(fake_pred)``.
    """
    mean_fake_score = torch.mean(fake_pred)
    return -mean_fake_score

In [125]:
def wasserstein_loss_crit(fake_pred, real_pred, penalty, lambda_ = 0.1):
    """Critic loss: score gap between fake and real plus the weighted penalty.

    Args:
        fake_pred: Critic scores for generated images.
        real_pred: Critic scores for real images.
        penalty: Gradient-penalty tensor.
        lambda_: Weight applied to ``penalty``.

    Returns:
        Scalar tensor ``mean(fake_pred) - mean(real_pred) + mean(lambda_ * penalty)``.
    """
    score_gap = torch.mean(fake_pred) - torch.mean(real_pred)
    weighted_penalty = torch.mean(lambda_ * (penalty))
    return score_gap + weighted_penalty

In [1]:
def generate_images(generator, n_examples, z_dim, num_images = 25, size = (1, 28, 28) , device = 'cpu'):
    """Sample noise, run it through ``generator`` and display the result.

    Args:
        generator: Callable mapping a ``(n_examples, z_dim)`` noise batch to images.
        n_examples: Number of noise vectors to sample.
        z_dim: Length of each noise vector.
        num_images: Forwarded to ``show_tensor_images``.
        size: Forwarded to ``show_tensor_images``.
        device: Device on which the noise is created.
    """
    latent_batch = generate_noise(n_examples, z_dim, device)
    show_tensor_images(generator(latent_batch), num_images, size)

In [152]:
def GAN(training_set, epochs, display_step, batch_size, crit_repeats, learning_rate, beta_1, beta_2, lambda_, z_dim, dataloader, loss, device):
    """Train a Generator/Critic pair and return the trained generator.

    Two objectives are supported, selected by ``loss``:

    * ``'W'``   - Wasserstein loss with gradient penalty; the critic is updated
      ``crit_repeats`` times for every generator update.
    * ``'BCE'`` - binary cross-entropy on the critic's raw scores; one critic
      update per generator update.

    After every ``display_step`` batches the mean losses over that window are
    printed, a grid of fake and real images is shown, and the loss history is
    plotted (averaged in bins of 20 steps).

    Args:
        training_set: Not used by this function; kept so existing calls work.
        epochs: Number of passes over ``dataloader``.
        display_step: Number of batches between progress reports.
        batch_size: Not used by this function; the batch size is taken from
            each batch yielded by ``dataloader``.
        crit_repeats: Critic updates per generator update (``'W'`` only).
        learning_rate: Adam learning rate for both networks.
        beta_1: First Adam beta for both optimizers.
        beta_2: Second Adam beta for both optimizers.
        lambda_: Gradient-penalty weight (``'W'`` only).
        z_dim: Length of the generator's noise vector.
        dataloader: Iterable yielding ``(images, labels)`` batches; labels are
            ignored.
        loss: ``'W'`` or ``'BCE'``.
        device: Device the networks and data are moved to.

    Returns:
        The trained ``Generator``.

    Raises:
        ValueError: If ``loss`` is neither ``'W'`` nor ``'BCE'``.
    """
    if loss not in ('W', 'BCE'):
        # Previously an unknown value silently iterated the data without
        # training anything and returned an untrained generator.
        raise ValueError(f"Unsupported loss {loss!r}; expected 'W' or 'BCE'")

    gen = Generator(z_dim).to(device)
    gen_optimizer = torch.optim.Adam(gen.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
    crit = Critic().to(device)
    crit_optimizer = torch.optim.Adam(crit.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
    # weights_init modifies the parameters in place, so the optimizers created
    # above keep referring to the same tensors.
    gen = gen.apply(weights_init)
    crit = crit.apply(weights_init)

    if loss == 'BCE':
        criterion = nn.BCEWithLogitsLoss()

    cur_step = 0
    # One entry per training step in each list, so they stay index-aligned.
    generator_losses = []
    critic_losses = []
    for epoch in range(epochs):
        for real, _ in tqdm(dataloader):
            cur_batch_size = len(real)
            real = real.to(device)
            if loss == 'W':
                # ---- critic: several updates per generator update ----
                mean_iteration_critic_loss = 0
                for _ in range(crit_repeats):
                    crit_optimizer.zero_grad()
                    fake_noise = generate_noise(cur_batch_size, z_dim, device = device)
                    # Detach once: the critic update must not reach the generator.
                    fake_imgs = gen(fake_noise).detach()
                    fake_pred = crit(fake_imgs)
                    real_pred = crit(real)
                    epsilon = torch.rand(len(real), 1, 1, 1, device = device, requires_grad = True)
                    grad = get_gradient(crit, real, fake_imgs, epsilon)
                    penalty = gradient_penalty(grad)
                    crit_loss = wasserstein_loss_crit(fake_pred, real_pred, penalty, lambda_)
                    mean_iteration_critic_loss += crit_loss.item() / crit_repeats
                    # The graph is used exactly once, so it need not be retained.
                    crit_loss.backward()
                    crit_optimizer.step()
                # Record the mean over all repeats once per step (not the
                # partial running sum after every repeat).
                critic_losses += [mean_iteration_critic_loss]
                cur_step += 1

                # ---- generator ----
                gen_optimizer.zero_grad()
                fake_noise_1 = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs_1 = gen(fake_noise_1)
                fake_pred_1 = crit(fake_imgs_1)
                gen_loss = wasserstein_loss_gen(fake_pred_1)
                gen_loss.backward()
                gen_optimizer.step()
                generator_losses += [gen_loss.item()]

            elif loss == 'BCE':
                # ---- critic ----
                crit_optimizer.zero_grad()
                fake_noise = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs = gen(fake_noise).detach()
                fake_pred = crit(fake_imgs)
                real_pred = crit(real)
                crit_loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
                crit_loss_real = criterion(real_pred, torch.ones_like(real_pred))
                crit_loss = (crit_loss_fake + crit_loss_real) / 2
                critic_losses += [crit_loss.item()]
                crit_loss.backward()
                crit_optimizer.step()
                cur_step += 1

                # ---- generator: wants its fakes to be scored as real ----
                gen_optimizer.zero_grad()
                fake_noise_1 = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs_1 = gen(fake_noise_1)
                fake_pred_1 = crit(fake_imgs_1)
                gen_loss = criterion(fake_pred_1, torch.ones_like(fake_pred_1))
                gen_loss.backward()
                gen_optimizer.step()
                generator_losses += [gen_loss.item()]

            if cur_step % display_step == 0 and cur_step > 0:
                gen_mean = sum(generator_losses[-display_step:]) / display_step
                crit_mean = sum(critic_losses[-display_step:]) / display_step
                print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
                show_tensor_images(fake_imgs)
                show_tensor_images(real)
                # Smooth the curves by averaging over bins of `step_bins` steps;
                # trailing steps that do not fill a whole bin are dropped.
                step_bins = 20
                num_examples = (len(generator_losses) // step_bins) * step_bins
                plt.plot(
                    range(num_examples // step_bins),
                    torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                    label="Generator Loss"
                )
                plt.plot(
                    range(num_examples // step_bins),
                    torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                    label="Critic Loss"
                )
                plt.legend()
                plt.show()
    return gen