In [1]:
import os
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import tqdm
import numpy as np
from torch import optim
import torchvision.utils as torch_utils
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from datetime import datetime

In [2]:
# Generator and Discriminator Utilities
class WSConv2d(nn.Module):
    """2-D convolution with a fixed input scale and a separately held bias.

    The wrapped ``nn.Conv2d`` has its weight re-initialised from N(0, 1) and
    its own bias removed. At call time the input is multiplied by
    ``sqrt(2 / (in_channels * kernel_size ** 2))`` before the convolution,
    and the bias (initialised to zero) is added per output channel afterwards.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer kernel size (it is squared to compute the scale).
        stride: convolution stride.
        padding: convolution padding.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
    ):
        super().__init__()
        fan_in = in_channels * (kernel_size ** 2)
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding
        )
        self.scale = (2 / fan_in) ** 0.5

        # Move the bias parameter onto this module so that it is added after
        # the convolution and is therefore not affected by the input scaling.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Overwrite the default Conv2d initialisation.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale)`` plus the per-channel bias."""
        convolved = self.conv(x * self.scale)
        channel_bias = self.bias.view(1, -1, 1, 1)
        return convolved + channel_bias
    

class ConvBlock(nn.Module):
    """Two stacked ``WSConv2d`` layers, each followed by ``LeakyReLU(0.2)``.

    Both convolutions use the ``WSConv2d`` defaults (kernel 3, stride 1,
    padding 1). The first maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply conv1 -> LeakyReLU -> conv2 -> LeakyReLU."""
        hidden = self.leaky(self.conv1(x))
        return self.leaky(self.conv2(hidden))
    


# Train and Eval utilities
def generate_examples(gen, steps, z_dim, n=100):
    """Sample ``n`` single images from ``gen`` and save them as PNG files.

    Images are written to ``saved_examples/step{steps}/img_{i}.png`` after
    being mapped from the generator's output range with ``img * 0.5 + 0.5``.

    Args:
        gen: callable module invoked as ``gen(noise, alpha, steps)``.
        steps: progressive-growing step forwarded to ``gen``; also names the
            output directory.
        z_dim: length of each latent noise vector.
        n: number of images to generate.

    Notes:
        * Relies on a module-level ``device`` being defined before the call.
        * NOTE(review): the ``Generator`` class in this notebook takes
          ``(noise, label, alpha, steps)``; this helper passes no label, so it
          only fits a generator with the three-argument signature.
        * The generator is always left in training mode on exit, including
          when generation raises.
    """
    out_dir = f'saved_examples/step{steps}'
    # Create the directory once up front instead of re-checking it for every
    # image; exist_ok avoids failing if it already exists.
    os.makedirs(out_dir, exist_ok=True)

    alpha = 1.0
    gen.eval()
    try:
        with torch.no_grad():
            for i in range(n):
                noise = torch.randn(1, z_dim).to(device)
                img = gen(noise, alpha, steps)
                save_image(img*0.5+0.5, f"{out_dir}/img_{i}.png")
    finally:
        # Restore training mode even if sampling or saving failed part-way.
        gen.train()

  
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Gradient penalty of ``critic`` at random blends of ``real`` and ``fake``.

    One uniform random coefficient is drawn per sample, each real image is
    blended with the (detached) fake image using that coefficient, and the
    critic is evaluated on the blend. The result is the batch mean of
    ``(||d score / d blend||_2 - 1) ** 2``, with the norm taken over all
    non-batch dimensions.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: 4-D tensor; its shape is unpacked as (batch, C, H, W).
        fake: tensor combined element-wise with ``real``.
        alpha: forwarded unchanged to ``critic``.
        train_step: forwarded unchanged to ``critic``.
        device: device the blend coefficients are moved to.

    Returns:
        Scalar tensor; the graph is kept (``create_graph=True``) so the
        penalty itself can be back-propagated.
    """
    n_samples, n_channels, height, width = real.shape

    # One coefficient per sample, copied across channels and pixels.
    mix = torch.rand((n_samples, 1, 1, 1))
    mix = mix.repeat(1, n_channels, height, width).to(device)

    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    scores = critic(blended, alpha, train_step)

    # d(scores) / d(blended); autograd.grad returns a one-element tuple here.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    flat_grads = grads.view(grads.shape[0], -1)
    per_sample_norm = flat_grads.norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

def get_loader(image_size, channels_img, batch_sizes, dataset_dir):
    """Build a shuffled ``DataLoader`` over an ``ImageFolder`` dataset.

    Every image is resized to ``image_size`` x ``image_size``, converted to a
    tensor, randomly flipped horizontally (p=0.5) and normalised with mean
    0.5 and std 0.5 on each of the ``channels_img`` channels.

    Args:
        image_size: target side length in pixels.
        channels_img: number of image channels to normalise.
        batch_sizes: sequence of batch sizes; the entry at index
            ``int(log2(image_size / 4))`` is used.
        dataset_dir: root directory passed to ``datasets.ImageFolder``.

    Returns:
        ``(loader, dataset)``.
    """
    mean = [0.5] * channels_img
    std = [0.5] * channels_img
    pipeline = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean, std),
    ])

    # Index 0 corresponds to 4x4 images, index 1 to 8x8, and so on.
    resolution_index = int(math.log2(image_size/4))
    dataset = datasets.ImageFolder(root=dataset_dir, transform=pipeline)
    loader = DataLoader(
        dataset,
        batch_size=batch_sizes[resolution_index],
        shuffle=True,
    )
    return loader, dataset

def check_loader(dataset_dir=None, image_size=128, channels_img=3, batch_sizes=None):
    """Plot a 3x3 grid of real samples drawn from ``get_loader``.

    The previous version called ``get_loader(128)``, which always raised a
    ``TypeError`` because ``get_loader`` requires four arguments. The missing
    arguments are now parameters with defaults.

    Args:
        dataset_dir: ``ImageFolder`` root to sample from (required).
        image_size: side length the images are resized to.
        channels_img: number of image channels.
        batch_sizes: batch-size table indexed as in ``get_loader``; when
            omitted, a table with 9 in every slot up to the needed index is
            used so that one batch can fill the grid.

    Raises:
        ValueError: if ``dataset_dir`` is not provided.
    """
    if dataset_dir is None:
        raise ValueError(
            "check_loader requires dataset_dir (the ImageFolder root to sample from)"
        )

    grid = 3
    if batch_sizes is None:
        # get_loader reads batch_sizes[int(log2(image_size / 4))].
        batch_sizes = [grid * grid] * (int(math.log2(image_size / 4)) + 1)

    loader, _ = get_loader(image_size, channels_img, batch_sizes, dataset_dir)
    cloth, _ = next(iter(loader))
    _, ax = plt.subplots(grid, grid, figsize=(8, 8))
    plt.suptitle('Some real samples')
    for ind, axis in enumerate(ax.flat):
        if ind >= len(cloth):
            # The batch held fewer than 9 images; leave the remaining cells empty.
            axis.axis("off")
            continue
        # (C, H, W) -> (H, W, C), and undo the mean 0.5 / std 0.5 normalisation.
        axis.imshow((cloth[ind].permute(1, 2, 0) + 1) / 2)

In [3]:
# Channel multipliers applied to the Generator's `in_channels`: entry i scales
# the input width of progressive block i and entry i + 1 its output width.
# Four stages at full width, then halving: 1/2, 1/4, 1/8, 1/16, 1/32.
image_factors = [1] * 4 + [1 / 2 ** k for k in range(1, 6)]

# Normalization on every element of input vector
# Adapted from StyleGAN original Implementation
class PixelNorm(nn.Module):
    """Divide the input by the root-mean-square of its values along dim 1.

    A constant ``1e-8`` is added under the square root, so an all-zero
    vector maps to zeros instead of dividing by zero.
    """

    def forward(self, x):
        """Return ``x / sqrt(mean(x ** 2, dim=1) + 1e-8)``."""
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + 1e-8).sqrt()
    
# Implementing the Noise Mapping Network
class WSLinear(nn.Module):
    """Linear layer with a fixed input scale and a separately held bias.

    The wrapped ``nn.Linear`` has its weight re-initialised from N(0, 1) and
    its own bias removed. At call time the input is multiplied by
    ``sqrt(2 / in_features)`` before the matrix product, and the bias
    (initialised to zero) is added afterwards.

    Args:
        in_features: size of each input vector.
        out_features: size of each output vector.
    """

    def __init__(
        self, in_features, out_features,
    ):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.scale = (2 / in_features) ** 0.5

        # Move the bias parameter onto this module so that it is added after
        # the linear map and is therefore not affected by the input scaling.
        self.bias = self.linear.bias
        self.linear.bias = None

        # Overwrite the default Linear initialisation.
        nn.init.normal_(self.linear.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``linear(x * scale) + bias``."""
        scaled = x * self.scale
        return self.linear(scaled) + self.bias
    

class NoiseMappingNetwork(nn.Module):
    """Mapping network: ``PixelNorm`` followed by eight ``WSLinear`` layers.

    The first linear layer maps ``z_dim`` to ``w_dim``; the remaining seven
    map ``w_dim`` to ``w_dim``. A ``ReLU`` sits between consecutive linear
    layers; there is no activation after the last one.
    """

    def __init__(self, z_dim, w_dim):
        super().__init__()
        # Layers are created in the same order they are applied, so the
        # Sequential indices are 0: PixelNorm, odd: WSLinear, even >= 2: ReLU.
        stages = [PixelNorm(), WSLinear(z_dim, w_dim)]
        for _ in range(7):
            stages.append(nn.ReLU())
            stages.append(WSLinear(w_dim, w_dim))
        self.noise_mapping = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the whole stack."""
        return self.noise_mapping(x)
    

# Adaptive Instance Normalization (AdaIn)
class AdaIN(nn.Module):
    """Instance-normalise ``x``, then apply a per-channel scale and bias from ``w``.

    Two ``WSLinear`` layers map ``w`` (``w_dim`` features) to ``channels``
    values each: one is used as the multiplicative scale and the other as the
    additive bias.
    """

    def __init__(self, channels, w_dim):
        super().__init__()
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.style_scale = WSLinear(w_dim, channels)
        self.style_bias = WSLinear(w_dim, channels)

    def forward(self, x, w):
        """Return ``scale(w) * instance_norm(x) + bias(w)``."""
        normalized = self.instance_norm(x)
        # Insert two trailing singleton axes (positions 2 and 3) so the
        # per-channel values broadcast over the spatial dimensions.
        gain = self.style_scale(w)[:, :, None, None]
        shift = self.style_bias(w)[:, :, None, None]
        return gain * normalized + shift
    

class NoiseInjectNet(nn.Module):
    """Add freshly sampled Gaussian noise to ``x`` with a learned per-channel weight.

    The weight has shape ``(1, channels, 1, 1)`` and starts at zero, so the
    layer initially returns its input unchanged. A single noise map per
    sample is shared by all channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        """Return ``x + weight * noise`` with noise of shape (batch, 1, H, W)."""
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        noise = torch.randn((batch, 1, height, width), device=x.device)
        return x + self.weight * noise

class GenBlock(nn.Module):
    """Generator block: two rounds of conv -> noise -> LeakyReLU -> AdaIN.

    The first ``WSConv2d`` maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both AdaIN layers are driven by the same ``w``.
    """

    def __init__(self, in_channels, out_channels, w_dim):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = NoiseInjectNet(out_channels)
        self.inject_noise2 = NoiseInjectNet(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        """Apply both rounds to ``x`` using the style vector ``w``."""
        # Round 1
        x = self.conv1(x)
        x = self.inject_noise1(x)
        x = self.leaky(x)
        x = self.adain1(x, w)

        # Round 2
        x = self.conv2(x)
        x = self.inject_noise2(x)
        x = self.leaky(x)
        return self.adain2(x, w)

class Generator(nn.Module):
    """Label-conditioned, progressively grown generator.

    A learned 4x4 constant is styled by the mapped noise vector, combined with
    a 4x4 embedding of the label, and then passed through ``steps`` progressive
    blocks, each preceded by a 2x bilinear upsample. The last two resolutions
    are converted to image channels and blended by ``fade_in``.

    Args:
        z_dim: size of the input noise vector.
        w_dim: size of the mapped style vector.
        in_channels: channel count of the 4x4 stage; later stages use
            ``int(in_channels * image_factors[i])``.
        img_channels: number of channels in the generated image.
        classes: length of the label vector fed to the embedding layer.
    """

    def __init__(self, z_dim, w_dim, in_channels, img_channels=3, classes=3):
        super(Generator, self).__init__()
        # Maps a label vector of length `classes` to 16 values (one 4x4 plane).
        self.embedding = nn.Linear(classes, 4*4)

        self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
        self.map = NoiseMappingNetwork(z_dim, w_dim)
        self.initial_adain1 = AdaIN(in_channels, w_dim)
        self.initial_adain2 = AdaIN(in_channels, w_dim)
        self.initial_noise1 = NoiseInjectNet(in_channels)
        self.initial_noise2 = NoiseInjectNet(in_channels)
        self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(image_factors) - 1):  # -1 to prevent index error because of factors[i+1]
            conv_in_c = int(in_channels * image_factors[i])
            conv_out_c = int(in_channels * image_factors[i + 1])
            self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two image-space tensors and squash with tanh.

        ``alpha`` should be a scalar in [0, 1] and ``upscaled`` must have the
        same shape as ``generated``; ``alpha == 1`` uses only ``generated``.
        """
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, noise, label, alpha, steps):
        """Generate images.

        Args:
            noise: latent tensor fed to the mapping network.
            label: tensor fed to the label embedding (last dim ``classes``).
            alpha: fade-in weight passed to ``fade_in`` when ``steps >= 1``.
            steps: number of progressive blocks to run; ``0`` returns the 4x4
                image directly (without ``fade_in``/tanh).

        Returns:
            Image tensor with ``img_channels`` channels.
        """
        w = self.map(noise)
        x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)

        # Label plane: (batch, 16) -> (batch, 1, 4, 4).
        label_embedding = self.embedding(label).view(-1, 1, 4, 4)
        # Replicate every embedding pixel `factor` times along both spatial
        # axes so the plane matches x; with the 4x4 constant, factor is 1.
        factor = x.shape[-1]//label_embedding.shape[-1]
        a, b, c, d = label_embedding.shape
        label_embedding = label_embedding.view(a, b, c, 1, d, 1)
        label_embedding = label_embedding.repeat(1, 1, 1, factor, 1, factor)
        label_embedding = label_embedding.reshape(a, b, x.shape[-1], x.shape[-1])

        # The single-channel label plane broadcasts over all feature channels.
        x = x + label_embedding
        x = self.initial_conv(x)
        out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)

        if steps == 0:
            # Convert the fully processed 4x4 features (`out`), not the raw
            # conv output `x`: for steps >= 1 the same `initial_rgb` layer is
            # applied to the upsampled `out`, so using `x` here made the 4x4
            # stage inconsistent with the fade-in path and bypassed the
            # second noise/LeakyReLU/AdaIN stage entirely.
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
            out = self.prog_blocks[step](upscaled, w)

        # The number of channels in upscale will stay the same, while
        # out which has moved through prog_blocks might change. To ensure
        # we can convert both to rgb we use different rgb_layers
        # (steps-1) and steps for upscaled, out respectively
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)



In [4]:
# Make `Generator` reachable as `__main__.Generator`. The checkpoints loaded
# further down are whole pickled modules; presumably they were saved with the
# class living in `__main__`, so unpickling has to find it there -- confirm
# against how the models were saved.
import __main__

__main__.Generator = Generator

In [5]:
# Device on which the inference inputs (noise and labels) are allocated.
device: str = "cpu"
# Length of the latent noise vector sampled for each generated image.
z_dim: int = 512

In [6]:
def save_images(images, output_dir):
    """Write each image tensor in ``images`` to ``output_dir`` as a PNG file.

    Files are named ``image_<timestamp>_<idx>.png`` where the timestamp has
    microsecond resolution and ``idx`` is the position in ``images``.

    Args:
        images: iterable of tensors that are permuted with ``(1, 2, 0)``
            before saving -- presumably (C, H, W) image tensors; confirm
            against the caller.
        output_dir: destination directory; created if missing.

    Returns:
        ``True`` once every image has been written.
    """
    print('Saving Images...')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    for idx, img in enumerate(images):
        # detach()/cpu() are no-ops for plain CPU tensors but let .numpy()
        # succeed for tensors that track gradients or live on an accelerator.
        pixels = img.detach().cpu().permute(1, 2, 0).numpy()
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        filename = f'image_{timestamp}_{idx}.png'
        filepath = os.path.join(output_dir, filename)
        plt.imsave(filepath, pixels)
    print("Done!")
    return True

In [7]:
def eval(classes, model_path, num_images, output_dir, steps=4, samples_per_class=10):
    """Load a saved generator and write ``num_images`` class-sorted sample grids.

    Each grid contains ``samples_per_class`` images for every class, one class
    per row, generated from a fixed noise batch and one-hot labels. The grids
    are saved through ``save_images``.

    Note: this function shadows the built-in ``eval`` inside the notebook; the
    name is kept because existing cells call it.

    Args:
        classes: number of classes (length of each one-hot label).
        model_path: path to a checkpoint holding a whole pickled generator.
        num_images: number of grids to generate.
        output_dir: directory the grids are written to.
        steps: progressive-growing step passed to the generator (default 4,
            the value that was previously hard-coded).
        samples_per_class: images per class and grid row width (default 10,
            the value that was previously hard-coded).

    Returns:
        The return value of ``save_images`` (``True`` on success).
    """
    total = classes * samples_per_class
    fixed_noise = torch.randn(total, z_dim, device=device)
    # Row i of the identity is the one-hot label of class i; repeating each
    # row gives samples_per_class consecutive labels per class.
    fixed_labels = torch.eye(classes, device=device).repeat_interleave(
        samples_per_class, dim=0
    )

    # SECURITY: the checkpoint is a full pickled module, so loading it can run
    # arbitrary code. Only load files from a trusted source.
    # weights_only=False is passed explicitly because newer PyTorch releases
    # default to weights_only=True, which rejects whole-module pickles.
    # map_location follows the global `device` so the model and its inputs
    # always end up on the same device.
    try:
        gen_net = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the weights_only keyword.
        gen_net = torch.load(model_path, map_location=device)

    gen_net.eval()
    images = []
    print("Starting Inference Loop...")
    with torch.no_grad():
        for _ in range(num_images):
            fake = gen_net(fixed_noise, fixed_labels, 1, steps).detach().cpu()
            images.append(
                torch_utils.make_grid(
                    fake, padding=2, nrow=samples_per_class, normalize=True
                )
            )

    return save_images(images, output_dir)

In [8]:
# Build checkpoint paths with os.path.join so the notebook also runs where
# "\\" is not a path separator (the paths are unchanged on Windows).
MODEL_DIR = os.path.join("..", "Outputs", "StyleGAN", "models")

eval(10, os.path.join(MODEL_DIR, "stylegan_mnist.pth"), 5, "Plots1")
eval(5, os.path.join(MODEL_DIR, "stylegan_flowers.pth"), 5, "Plots2")
eval(3, os.path.join(MODEL_DIR, "stylegan_shoe.pth"), 5, "Plots3")

Starting Inference Loop...
Saving Images...
Done!
Starting Inference Loop...
Saving Images...
Done!
Starting Inference Loop...
Saving Images...
Done!


True