In [1]:
import torch
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import numpy as np
from tqdm.auto import tqdm
from torchvision.utils import make_grid

In [23]:
# ---- Training schedule ----
epochs = 300
display_step = 30      # the training loop shows samples / loss curves every `display_step` steps
batch_size = 32
crit_repeats = 5       # critic updates per generator update (used by the 'W' loss branch)

# ---- Optimiser (Adam) and loss weighting ----
learning_rate = 0.002
beta_1 = 0.9
beta_2 = 0.999
lambda_ = 10           # weight of the gradient penalty in the critic loss

# ---- Model / data dimensions ----
z_dim = 128            # size of the sampled noise vector
w_dim = 496            # size of the mapping-network output
hidden_dim = 256
img_channels = 3
image_size = 1024
n_images = batch_size
# Derived from the values above so the three cannot drift apart.
size = (img_channels, image_size, image_size)
numClasses = 0

# ---- Runtime ----
# Fall back to the CPU instead of failing when no CUDA device is present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
loss = 'W'             # 'W' (Wasserstein + gradient penalty) or 'BCE'
trained = False        # False: the training function builds fresh models and optimisers

In [3]:
# Root folder handed to ImageFolder when the training set is built.
DATA_DIR = 'FFHQ Faces'
# (mean, std) pair, one value per RGB channel; unpacked into T.Normalize.
stats = ((0.5,) * 3, (0.5,) * 3)

In [4]:
# Training set: every image is resized, centre-cropped to a square of
# `image_size`, converted to a tensor and normalised with `stats`.
train_ds = ImageFolder(
    root = DATA_DIR,
    transform = T.Compose([
        T.Resize(image_size),
        T.CenterCrop(image_size),
        T.ToTensor(),
        T.Normalize(mean = stats[0], std = stats[1]),
    ]),
)

# Shuffled mini-batches, loaded by two worker processes into pinned memory.
dataloader = DataLoader(
    dataset = train_ds,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 2,
    pin_memory = True,
)

In [5]:
def show_tensor_images(image_tensor, num_images = 25, size = (1, 28, 28)):
    """Display the first `num_images` images of a batch as a grid, 5 per row.

    Args:
        image_tensor: batch of images; values are mapped with (x + 1) / 2
            before display, i.e. the input is expected to lie in [-1, 1].
        num_images: maximum number of images taken from the front of the batch.
        size: not used by this function; kept so existing callers that pass
            it keep working.
    """
    # Slice first so only the images actually shown are copied off the device.
    shown = image_tensor[:num_images].detach().cpu()
    # Map to [0, 1] and clamp: imshow clips float RGB data to that range
    # anyway, but emits a warning whenever a value falls outside it.
    shown = ((shown + 1) / 2).clamp(0, 1)
    image_grid = make_grid(shown, nrow = 5)
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [6]:
def generate_noise(n_examples, z_dim, device = 'cpu'):
    """Draw a batch of standard-normal noise vectors.

    Args:
        n_examples: number of vectors (rows) to sample.
        z_dim: length of each vector.
        device: device the tensor is created on.

    Returns:
        Tensor of shape (n_examples, z_dim).
    """
    return torch.randn((n_examples, z_dim), device = device)

In [7]:
def weights_init(m):
    """Initialise one module in place; meant to be used with `Module.apply`.

    * Conv2d / ConvTranspose2d: weight ~ N(0, 0.02).
    * BatchNorm2d / Linear: weight ~ N(0, 0.02), bias = 0.

    Modules of any other type are left untouched. Parameters that do not
    exist (`Linear(bias=False)`, `BatchNorm2d(affine=False)`) are skipped
    instead of raising.

    NOTE(review): BatchNorm scale is drawn around 0 here, not around 1 as in
    some DCGAN reference code -- kept as in the original, confirm it is intended.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, (nn.BatchNorm2d, nn.Linear)):
        # weight/bias are None when the layer was built without them.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

In [8]:
class MappingNetwork(nn.Module):
    """Eight-layer MLP that maps a noise vector to an intermediate vector.

    Layout: Linear(z_dim, hidden_dim), six Linear(hidden_dim, hidden_dim) and
    a final Linear(hidden_dim, w_dim), with a ReLU between consecutive Linear
    layers (none after the last one).
    """

    def __init__(self, z_dim, hidden_dim, w_dim):
        super().__init__()
        # Feature widths seen along the stack: z -> h (x7) -> w.
        widths = [z_dim] + [hidden_dim] * 7 + [w_dim]
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        self.seq = nn.Sequential(*layers)

    def forward(self, noise):
        """Return the mapped vector for a batch of noise vectors."""
        return self.seq(noise)

In [9]:
class InjectNoise(nn.Module):
    """Add per-pixel Gaussian noise scaled by a learned per-channel weight."""

    def __init__(self, im_dim):
        super().__init__()
        # One learnable scale per channel, stored as (1, im_dim, 1, 1) so it
        # broadcasts over the batch and spatial dimensions.
        self.weights = nn.Parameter(torch.randn(im_dim).view(1, im_dim, 1, 1))

    def forward(self, X):
        """Return X plus freshly sampled noise.

        A single noise channel of shape (batch, 1, H, W) is drawn and shared
        by all channels of X; only the per-channel scale differs.
        """
        batch, height, width = X.shape[0], X.shape[2], X.shape[3]
        noise = torch.randn((batch, 1, height, width), device = X.device)
        return X + self.weights * noise

In [10]:
class AdaIN(nn.Module):
    """Adaptive instance normalisation.

    The feature map is instance-normalised, then scaled and shifted per
    channel by two linear projections of the intermediate vector.
    """

    def __init__(self, im_dim, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(im_dim)
        self.style_scale = nn.Linear(w_dim, im_dim)
        self.style_shift = nn.Linear(w_dim, im_dim)

    def forward(self, X, inter_noise):
        """Apply the style in `inter_noise` (batch, w_dim) to the feature map X."""
        normalized = self.instance_norm(X)
        # (batch, im_dim) -> (batch, im_dim, 1, 1) so the style broadcasts
        # over the spatial dimensions.
        scale = self.style_scale(inter_noise).unsqueeze(-1).unsqueeze(-1)
        shift = self.style_shift(inter_noise).unsqueeze(-1).unsqueeze(-1)
        return scale * normalized + shift

In [11]:
class GeneratorBlock(nn.Module):
    """One synthesis stage: (upsample) -> conv -> noise -> LeakyReLU -> AdaIN.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (int or tuple). Padding is
            derived from it as `kernel_size // 2`, so odd kernels keep the
            spatial size; for the kernel size 3 used in this file this is
            the same padding of 1 as before.
        starting_size: target spatial size of the bilinear upsampling.
        w_dim: length of the intermediate vector consumed by AdaIN.
        use_upsample: when False the input is convolved at its own size and
            no upsampling layer is created.
    """

    def __init__(self, in_channels, out_channels, kernel_size, starting_size, w_dim, use_upsample = True):
        super().__init__()
        self.use_upsample = use_upsample
        if self.use_upsample:
            self.upsample = nn.Upsample(starting_size, mode = 'bilinear')
        # "Same" padding tied to the kernel instead of a hard-coded 1.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding = padding)
        self.inject_noise = InjectNoise(out_channels)
        self.adaIn = AdaIN(out_channels, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, X, inter_noise):
        """Run the stage on feature map X, styled by `inter_noise`."""
        if self.use_upsample:
            X = self.upsample(X)
        features = self.conv(X)
        features = self.inject_noise(features)
        features = self.activation(features)
        return self.adaIn(features, inter_noise)

In [12]:
class Generator(nn.Module):
    """Style-based generator built from nine GeneratorBlocks (4x4 up to 1024x1024).

    A learned 4x4 constant is pushed through `block_1`; the remaining eight
    blocks are processed in four pairs. In each pair both blocks are passed
    through a 1x1 convolution (`block_to_image_*`), the lower-resolution
    result is upsampled to the higher resolution and the two are blended as
    `alpha * high + (1 - alpha) * low`. The first six 1x1 convolutions keep
    `hidden_dim` channels; the last two emit `out_channels`.

    `use_upsample` is stored on the module but not forwarded to the blocks.
    """

    def __init__(self, in_channels, out_channels, kernel_size, w_dim, z_dim, hidden_dim, alpha = 0.2, use_upsample = True):
        super().__init__()
        self.starting_constant = nn.Parameter(torch.randn(1, in_channels, 4, 4))
        self.use_upsample = use_upsample
        self.mapping_network = MappingNetwork(z_dim, hidden_dim, w_dim)
        # block_1 works directly on the 4x4 constant; it does not upsample.
        self.block_1 = GeneratorBlock(in_channels, hidden_dim, kernel_size, 4, w_dim, use_upsample = False)
        # block_2 .. block_9 upsample to 8, 16, ..., 1024.
        for idx in range(2, 10):
            resolution = 2 ** (idx + 1)
            setattr(self, f'block_{idx}', GeneratorBlock(hidden_dim, hidden_dim, kernel_size, resolution, w_dim))
        # block_to_image_1 .. _6 stay in feature space, _7 and _8 emit the image.
        for idx in range(1, 9):
            n_out = hidden_dim if idx <= 6 else out_channels
            setattr(self, f'block_to_image_{idx}', nn.Conv2d(hidden_dim, n_out, kernel_size = 1))
        self.alpha = alpha

    def upsample_to_match_size(self, smaller_image, bigger_image):
        """Bilinearly resize `smaller_image` to the spatial size of `bigger_image`."""
        target = bigger_image.shape[-2:]
        return F.interpolate(smaller_image, size = target, mode = 'bilinear')

    def forward(self, noise):
        """Generate images from a batch of noise vectors of length z_dim."""
        style = self.mapping_network(noise)
        current = self.block_1(self.starting_constant, style)

        # Stage k uses block_{2k+2} / block_{2k+3} with
        # block_to_image_{2k+1} / block_to_image_{2k+2}:
        # 8->16, 32->64, 128->256, 512->1024.
        for stage in range(4):
            low_block = getattr(self, f'block_{2 * stage + 2}')
            high_block = getattr(self, f'block_{2 * stage + 3}')
            low_to_image = getattr(self, f'block_to_image_{2 * stage + 1}')
            high_to_image = getattr(self, f'block_to_image_{2 * stage + 2}')

            low_features = low_block(current, style)
            low_image = low_to_image(low_features)
            high_features = high_block(low_features, style)
            high_image = high_to_image(high_features)

            low_image_up = self.upsample_to_match_size(low_image, high_image)
            current = (self.alpha) * (high_image) + (1 - self.alpha) * (low_image_up)

        return current

In [25]:
class CriticBlock(nn.Module):
    """Strided convolution followed by batch norm and LeakyReLU(0.2).

    NOTE(review): batch normalisation couples the samples of a batch, while
    the gradient penalty used with the Wasserstein loss is defined per
    sample -- worth confirming this is intended.
    """

    def __init__(self, input_channels, output_channels, kernel_size = 4, stride = 2):
        super().__init__()
        self.conv2d_1 = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
        self.batchnorm2d_1 = nn.BatchNorm2d(output_channels)
        self.activation = nn.LeakyReLU(0.2, inplace = True)

    def forward(self, X):
        """Return the activated, normalised convolution of X."""
        return self.activation(self.batchnorm2d_1(self.conv2d_1(X)))

In [28]:
class Critic(nn.Module):
    """Convolutional critic that scores a batch of images.

    A 1x1 convolution lifts the image to `hidden_dims` channels; six
    CriticBlocks then double the channel count each time (x2 ... x64), and a
    final strided convolution reduces to a single channel. The result is
    flattened to (batch, -1).

    Args:
        img_channels: channels of the input images.
        hidden_dims: base channel width of the network.
        alpha: stored on the module (the training loop updates it); it is
            not read by `forward`.
    """

    def __init__(self, img_channels, hidden_dims, alpha):
        super().__init__()
        self.alpha = alpha
        self.conv_1 = nn.Conv2d(img_channels, hidden_dims, kernel_size = 1)
        self.block_1 = CriticBlock(hidden_dims, hidden_dims * 2)
        self.block_2 = CriticBlock(hidden_dims * 2, hidden_dims * 4)
        self.block_3 = CriticBlock(hidden_dims * 4, hidden_dims * 8)
        self.block_4 = CriticBlock(hidden_dims * 8, hidden_dims * 16)
        self.block_5 = CriticBlock(hidden_dims * 16, hidden_dims * 32)
        self.block_6 = CriticBlock(hidden_dims * 32, hidden_dims * 64)
        # Must use the constructor argument: the previous code read the
        # notebook-level `hidden_dim`, which only matched by coincidence.
        self.conv_2 = nn.Conv2d(hidden_dims * 64, 1, kernel_size = 4, stride = 2)

    def forward(self, img):
        """Return one row of critic scores per image, shape (batch, -1)."""
        features = self.conv_1(img)

        features = self.block_1(features)
        features = self.block_2(features)
        features = self.block_3(features)
        features = self.block_4(features)
        features = self.block_5(features)
        features = self.block_6(features)

        scores = self.conv_2(features)

        return scores.view(len(scores), -1)

In [27]:
def get_gradient(crit, real, fake, epsilon):
    """Gradient of the critic's scores with respect to a real/fake mix.

    Args:
        crit: callable mapping an image batch to scores.
        real: batch of real images.
        fake: batch of generated images, same shape as `real`.
        epsilon: mixing weights, broadcastable against the images; the mixed
            images are `real * epsilon + fake * (1 - epsilon)`.

    Returns:
        Gradient tensor with the shape of the mixed images. It is built with
        `create_graph=True`, so a loss computed from it can be backpropagated.
    """
    mixed_images = real * epsilon + fake * (1 - epsilon)
    mixed_scores = crit(mixed_images)
    (grad,) = torch.autograd.grad(
        outputs = mixed_scores,
        inputs = mixed_images,
        grad_outputs = torch.ones_like(mixed_scores),
        create_graph = True,
        retain_graph = True,
    )
    return grad

In [16]:
def gradient_penalty(gradient):
    """Mean squared distance of each sample's gradient L2 norm from 1.

    Args:
        gradient: tensor whose first dimension indexes the samples; all other
            dimensions are flattened before the norm is taken.

    Returns:
        Scalar tensor.
    """
    per_sample = gradient.view(len(gradient), -1)
    norms = per_sample.norm(2, dim = 1)
    return torch.mean((norms - 1) ** 2)

In [17]:
def wasserstein_loss_gen(fake_pred):
    """Generator loss: the negated mean critic score of the generated images."""
    mean_score = torch.mean(fake_pred)
    return -mean_score

In [18]:
def wasserstein_loss_crit(fake_pred, real_pred, penalty, lambda_ = 0.1):
    """Critic loss: mean(fake) - mean(real) + mean(lambda_ * penalty).

    Args:
        fake_pred: critic scores for generated images.
        real_pred: critic scores for real images.
        penalty: gradient-penalty tensor.
        lambda_: weight applied to the penalty.
    """
    fake_mean = torch.mean(fake_pred)
    real_mean = torch.mean(real_pred)
    weighted_penalty = torch.mean(lambda_ * (penalty))
    return fake_mean - real_mean + weighted_penalty

In [19]:
def generate_images(generator, n_examples, z_dim, conditional, numClasses = 0, num_images = 25, size = (1, 28, 28), labels = None, device = 'cpu'):
    """Sample images from `generator` and display them.

    Args:
        generator: callable mapping a noise batch to an image batch.
        n_examples: number of noise vectors to sample.
        z_dim: length of each noise vector.
        conditional: when truthy, class labels are one-hot encoded and
            combined with the noise before it is passed to the generator.
        numClasses: number of classes (conditional mode only).
        num_images: number of images shown; also the number of labels drawn
            or replicated in conditional mode.
        size: forwarded to `show_tensor_images`.
        labels: conditional mode only. None or empty -> random labels;
            a single label -> that label repeated; otherwise used as given.
            The caller's object is never modified.
        device: device for noise and labels.

    NOTE(review): labels are created per `num_images` while noise is created
    per `n_examples`; presumably callers pass equal values -- confirm.
    `oneHotEncode` and `combineVectors` are not defined in this section and
    are expected to be provided elsewhere in the notebook.
    """
    if conditional:
        # `random` is not imported at file level; it is only needed here.
        import random
        if labels is None or len(labels) == 0:
            # Build a fresh list: the previous code appended to the (shared)
            # default argument, so labels leaked from one call to the next.
            labels = torch.Tensor(
                [random.randint(0, numClasses - 1) for _ in range(num_images)]
            )
        elif len(labels) == 1:
            # floor(label + U[0, 1)) repeats the given label num_images times.
            labels = torch.floor(labels[0] + torch.rand(num_images, ))
        # Accept lists as well as tensors and make sure the labels sit on
        # `device` before they are encoded.
        labels = torch.as_tensor(labels, device = device).to(torch.int64)
        one_hot = oneHotEncode(numClasses, labels)
        pred_noise = generate_noise(n_examples, z_dim, device)
        pred_noise_labels = combineVectors(pred_noise, one_hot)
        pred_images = generator(pred_noise_labels)
    else:
        pred_noise = generate_noise(n_examples, z_dim, device)
        pred_images = generator(pred_noise)
    show_tensor_images(pred_images, num_images, size)

In [22]:
def _show_training_progress(epoch, cur_step, display_step, generator_losses, critic_losses, fake_imgs, real, shape):
    """Print recent mean losses, show fake and real samples, plot binned loss curves."""
    gen_mean = sum(generator_losses[-display_step:]) / display_step
    crit_mean = sum(critic_losses[-display_step:]) / display_step
    print(f"Epoch {epoch}, step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
    show_tensor_images(fake_imgs, size = shape)
    show_tensor_images(real, size = shape)
    # Average the loss history in bins of `step_bins` steps for a smoother curve.
    step_bins = 20
    num_examples = (len(generator_losses) // step_bins) * step_bins
    plt.plot(
        range(num_examples // step_bins),
        torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Generator Loss"
    )
    plt.plot(
        range(num_examples // step_bins),
        torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
        label="Critic Loss"
    )
    plt.legend()
    plt.show()


def StyleGAN(trained, generator, critic, generator_optimizer, critic_optimizer, epochs, display_step, img_channels, crit_repeats, learning_rate, beta_1, beta_2, lambda_, z_dim, w_dim, hidden_dim, shape, dataloader, loss, device):
    """Train (or continue training) the generator/critic pair.

    Args:
        trained: False -> build fresh models and Adam optimisers and apply
            `weights_init`; True -> continue with the objects passed in.
        generator, critic, generator_optimizer, critic_optimizer: used only
            when `trained` is True.
        epochs: number of passes over `dataloader`.
        display_step: progress is shown every `display_step` training steps.
        img_channels: image channels; also used as the generator's input channels.
        crit_repeats: critic updates per generator update ('W' loss only).
        learning_rate, beta_1, beta_2: Adam settings for fresh optimisers.
        lambda_: gradient-penalty weight ('W' loss only).
        z_dim, w_dim, hidden_dim: model dimensions for fresh models.
        shape: forwarded to `show_tensor_images` as `size`.
        dataloader: yields (images, labels) batches; labels are ignored.
        loss: 'W' (Wasserstein with gradient penalty) or 'BCE'.
        device: device for models, data and noise.

    Returns:
        (gen, crit, gen_optimizer, crit_optimizer)

    Raises:
        ValueError: if `loss` is neither 'W' nor 'BCE'.
    """
    if loss not in ('W', 'BCE'):
        # Previously an unknown value silently iterated the data without training.
        raise ValueError(f"loss must be 'W' or 'BCE', got {loss!r}")

    if trained:
        gen = generator
        crit = critic
        gen_optimizer = generator_optimizer
        crit_optimizer = critic_optimizer
    else:
        gen = Generator(img_channels, img_channels, 3, w_dim, z_dim, hidden_dim).to(device)
        crit = Critic(img_channels, hidden_dim, alpha = 0.2).to(device)
        gen_optimizer = torch.optim.Adam(gen.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
        crit_optimizer = torch.optim.Adam(crit.parameters(), lr = learning_rate, betas = (beta_1, beta_2))
        gen = gen.apply(weights_init)
        crit = crit.apply(weights_init)

    criterion = nn.BCEWithLogitsLoss() if loss == 'BCE' else None

    # Blend-factor schedule: +alpha_step every alpha_every epochs, capped at 1.
    alpha_every = 30
    alpha_step = 0.1

    cur_step = 0
    generator_losses = []
    critic_losses = []
    for epoch in range(epochs):
        # Done once per epoch and clamped. The old per-batch update ran on
        # every batch of such an epoch, and its `alpha != 1` float test never
        # matched, so alpha grew without bound.
        if epoch % alpha_every == 0 and epoch != 0:
            gen.alpha = min(1.0, gen.alpha + alpha_step)
            crit.alpha = min(1.0, crit.alpha + alpha_step)

        for real, _ in tqdm(dataloader):
            cur_batch_size = len(real)
            real = real.to(device)

            if loss == 'W':
                mean_iteration_critic_loss = 0
                for _ in range(crit_repeats):
                    crit_optimizer.zero_grad()
                    fake_noise = generate_noise(cur_batch_size, z_dim, device = device)
                    epsilon = torch.rand(len(real), 1, 1, 1, device = device, requires_grad = True)
                    # The critic update must not backpropagate into the
                    # generator, so no generator graph is built here.
                    with torch.no_grad():
                        fake_imgs = gen(fake_noise)
                    fake_pred = crit(fake_imgs)
                    real_pred = crit(real)
                    grad = get_gradient(crit, real, fake_imgs, epsilon)
                    penalty = gradient_penalty(grad)
                    crit_loss = wasserstein_loss_crit(fake_pred, real_pred, penalty, lambda_)
                    mean_iteration_critic_loss += crit_loss.item() / crit_repeats
                    crit_loss.backward()
                    crit_optimizer.step()
                # One entry per training step, so the critic history stays
                # aligned with the generator history.
                critic_losses += [mean_iteration_critic_loss]
                cur_step += 1

                gen_optimizer.zero_grad()
                fake_noise_1 = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs_1 = gen(fake_noise_1)
                fake_pred_1 = crit(fake_imgs_1)
                gen_loss = wasserstein_loss_gen(fake_pred_1)
                gen_loss.backward()
                gen_optimizer.step()
                generator_losses += [gen_loss.item()]

            else:  # loss == 'BCE'
                crit_optimizer.zero_grad()
                fake_noise = generate_noise(cur_batch_size, z_dim, device = device)
                fake_imgs = gen(fake_noise)
                # Detach for the critic update; the generator graph is kept
                # untouched for the generator step below.
                fake_pred = crit(fake_imgs.detach())
                real_pred = crit(real)
                crit_loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
                crit_loss_real = criterion(real_pred, torch.ones_like(real_pred))
                crit_loss = (crit_loss_fake + crit_loss_real) / 2
                critic_losses += [crit_loss.item()]
                crit_loss.backward()
                crit_optimizer.step()
                cur_step += 1

                gen_optimizer.zero_grad()
                fake_pred_1 = crit(fake_imgs)
                gen_loss = criterion(fake_pred_1, torch.ones_like(fake_pred_1))
                gen_loss.backward()
                gen_optimizer.step()
                generator_losses += [gen_loss.item()]

            if cur_step % display_step == 0 and cur_step > 0:
                _show_training_progress(
                    epoch, cur_step, display_step,
                    generator_losses, critic_losses,
                    fake_imgs, real, shape,
                )
    return gen, crit, gen_optimizer, crit_optimizer